def build_model(args, corpus):
    criterion = None
    ntokens = len(corpus.dictionary)
    model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers,
                     args.dropout, args.dropouth, args.dropouti, args.dropoute,
                     args.wdrop, args.tied)
    ###
    if args.resume:
        logging.info('Resuming model ...')
        model, criterion, optimizer = model_load(args.resume_path)
        optimizer.param_groups[0]['lr'] = args.lr
        model.dropouti, model.dropouth, model.dropout, args.dropoute = \
            args.dropouti, args.dropouth, args.dropout, args.dropoute
        if args.wdrop:
            from weight_drop import WeightDrop
            for rnn in model.rnns:
                if type(rnn) == WeightDrop:
                    rnn.dropout = args.wdrop
                elif rnn.zoneout > 0:
                    rnn.zoneout = args.wdrop
    ###
    if not criterion:
        splits = []
        if ntokens > 500000:
            # One Billion
            # This produces fairly even matrix mults for the buckets:
            # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422
            splits = [4200, 35000, 180000]
        elif ntokens > 75000:
            # WikiText-103
            splits = [2800, 20000, 76000]
        logging.info(f'Using {splits}')
        criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False)
    ###
    params = list(model.parameters()) + list(criterion.parameters())
    total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0]
                       for x in params if x.size())
    logging.info(f'Args: {args}')
    logging.info(f'Model total parameters: {total_params}')
    if args.cuda:
        model = model.cuda()
        criterion = criterion.cuda()
    return model, criterion
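
# Example usage (a minimal sketch, not part of the original script): `args` is assumed to be
# an argparse.Namespace carrying the hyperparameters referenced above, and `corpus` a
# data.Corpus instance built from the same vocabulary.
#
#     corpus = data.Corpus(args.data)
#     model, criterion = build_model(args, corpus)
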
def run(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
        else:
            torch.cuda.manual_seed(args.seed)

    ###############################################################################
    # Load data
    ###############################################################################

    def model_save(fn):
        with open(fn, 'wb') as f:
            torch.save([model, optimizer], f)

    def model_load(fn):
        # Rebind the enclosing run() locals rather than module-level globals.
        nonlocal model, optimizer
        with open(fn, 'rb') as f:
            model, optimizer = torch.load(f)

    import os
    import hashlib
    fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest())
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    else:
        print('Producing dataset...')
        corpus = data.Corpus(args.data)
        torch.save(corpus, fn)

    # Get token frequencies and end-of-sentence tokens.
    frequencies, eos_tokens = None, None
    if not args.uni_freq:
        frequencies = corpus.frequencies
    if args.reinit_h:
        eos_tokens = corpus.reset_idxs

    # Batchify.
    eval_batch_size = 1
    test_batch_size = 1
    print(corpus.dictionary)
    if args.reinit_h:
        ntokens = len(corpus.dictionary) + 1 if args.batch_size > 1 else len(corpus.dictionary)
        train_data, seq_lens = batchify_padded(corpus.train, args.batch_size, args, ntokens, eos_tokens)
    else:
        ntokens = len(corpus.dictionary)
        train_data = batchify(corpus.train, args.batch_size, args)
    val_data = batchify(corpus.valid, eval_batch_size, args)
    test_data = batchify(corpus.test, test_batch_size, args)

    ###############################################################################
    # Build the model
    ###############################################################################

    model = RNNModel(ntokens, args.emsize, args.nhid, args.dropout, args.dropouth,
                     args.dropouti, args.dropoute, args.wdrop, args.nsamples,
                     args.temperature, frequencies, args.no_bias, args.bias_reg,
                     args.dist_fn, args.activation_fn)
    ###
    if args.resume:
        print('Resuming model ...')
        model_load(args.resume)
        optimizer.param_groups[0]['lr'] = args.lr
        model.dropouti, model.dropouth, model.dropout, args.dropoute = \
            args.dropouti, args.dropouth, args.dropout, args.dropoute
    ###
    if args.cuda:
        model = model.cuda()
    ###
    params = list(model.parameters())
    total_params = sum(x.size()[0] * x.size()[1] if len(x.size()) > 1 else x.size()[0]
                       for x in params if x.size())
    print('Args:', args)
    print('Model total parameters:', total_params)

    ###############################################################################
    # Training code
    ###############################################################################

    def evaluate(data_source, epoch, batch_size=1):
        # Turn on evaluation mode, which disables dropout.
        model.eval()
        if args.dump_hiddens:
            loss, entropy, hiddens = model.evaluate(data_source, eos_tokens, args.dump_hiddens)
            dump_hiddens(hiddens, 'hiddens_' + str(epoch))
        else:
            loss, entropy = model.evaluate(data_source, eos_tokens)
        if args.dump_words:
            dump_words(model.encoder.weight.detach().cpu().numpy(), 'words_' + str(epoch))
        if args.dump_entropy is not None:
            dump(entropy, args.dump_entropy + str(epoch))
        return loss

    def train():
        # Turn on training mode, which enables dropout.
        total_loss, avrg_loss = 0, 0
        start_time = time.time()
        ntokens = len(corpus.dictionary)
        batch, i = 0, 0
        hidden = model.init_hidden(args.batch_size)
        while i < train_data.size(0) - 1:
            if args.reinit_h:
                seq_len = seq_lens[batch] - 1
            else:
                bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
                # Prevent excessively small or negative sequence lengths.
                seq_len = max(5, int(np.random.normal(bptt, 5)))
                # There's a very small chance that it could select a very long
                # sequence length, resulting in OOM.
                # seq_len = min(seq_len, args.bptt + 10)

            # Rescale the learning rate to account for the variable sequence length.
            lr2 = optimizer.param_groups[0]['lr']
            optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
            model.train()
            data = get_batch(train_data, i, args, seq_len=seq_len)

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to the start of the dataset.
            reset_hidden = args.reinit_h
            if reset_hidden:
                hidden = model.init_hidden(args.batch_size)
            hidden = repackage_hidden(hidden)
            optimizer.zero_grad()

            # raw_loss = model.train_crossentropy(data, eos_tokens)
            raw_loss, hidden = model(data, hidden)
            loss = raw_loss
            '''
            See what we can do here! We don't need the regularization as it is implicit!
            # Activation Regularization
            if args.alpha:
                loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                                  for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness)
            if args.beta:
                loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                                  for rnn_h in rnn_hs[-1:])
            '''
            loss.backward()

            # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
            if args.clip:
                torch.nn.utils.clip_grad_norm_(params, args.clip)
            optimizer.step()

            total_loss += loss.data
            optimizer.param_groups[0]['lr'] = lr2
            if batch % args.log_interval == 0 and batch > 0:
                cur_loss = total_loss.item() / args.log_interval
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                      'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                          epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                          elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss),
                          cur_loss / math.log(2)))
                avrg_loss = avrg_loss + total_loss
                total_loss = 0
                start_time = time.time()
            ###
            batch += 1
            i += seq_len + 1
        return avrg_loss / train_data.size(0)

    # Loop over epochs.
    lr = args.lr
    best_val_loss = []
    valid_loss = []
    stored_loss = 100000000

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        optimizer = None
        # Ensure the optimizer is optimizing params, which includes both the model's
        # weights as well as the criterion's weight (i.e. Adaptive Softmax).
        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(params, lr=args.lr, weight_decay=args.wdecay)
        if args.optimizer == 'adam':
            optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay)

        for epoch in range(1, args.epochs + 1):
            epoch_start_time = time.time()
            train_loss = train()
            # Track the largest singular value of the recurrent weight matrix.
            _, s, _ = np.linalg.svd(model.rnn.module.weight_hh_l0.cpu().detach().numpy())
            print(s[0])
            # dump(model.decoder.bias.cpu().detach().numpy(), 'bias_' + str(epoch) + '.out')

            # Skip validation if this epoch is not an evaluation epoch.
            if epoch % args.evaluate_every > 0:
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | train loss {:5.2f} |'.format(
                    epoch, (time.time() - epoch_start_time), train_loss))
                print('-' * 89)
                continue

            # Evaluate validation loss.
            if 't0' in optimizer.param_groups[0]:
                # ASGD: temporarily swap in the averaged parameters for evaluation.
                tmp = {}
                for prm in model.parameters():
                    tmp[prm] = prm.data.clone()
                    if 'ax' in optimizer.state[prm]:
                        prm.data = optimizer.state[prm]['ax'].clone()

                val_loss2 = evaluate(val_data, epoch)
                valid_loss.append(val_loss2)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_loss2,
                          math.exp(val_loss2), val_loss2 / math.log(2)))
                print('-' * 89)

                if val_loss2 < stored_loss:
                    model_save(args.save)
                    print('Saving Averaged!')
                    stored_loss = val_loss2

                # Restore the un-averaged parameters before the next training epoch.
                for prm in model.parameters():
                    prm.data = tmp[prm].clone()
            else:
                val_loss = evaluate(val_data, epoch, eval_batch_size)
                valid_loss.append(val_loss)
                print('-' * 89)
                print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                      'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                          epoch, (time.time() - epoch_start_time), val_loss,
                          math.exp(val_loss), val_loss / math.log(2)))
                print('-' * 89)

                if val_loss < stored_loss:
                    model_save(args.save)
                    print('Saving model (new best validation)')
                    stored_loss = val_loss

                if (args.optimizer == 'sgd' and 't0' not in optimizer.param_groups[0] and
                        len(best_val_loss) > args.nonmono and val_loss > min(best_val_loss[:-args.nonmono])):
                    print('Switching to ASGD')
                    optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0.,
                                                 weight_decay=args.wdecay)

                if epoch in args.when:
                    print('Saving model before learning rate decreased')
                    model_save('{}.e{}'.format(args.save, epoch))
                    print('Dividing learning rate by 10')
                    optimizer.param_groups[0]['lr'] /= 10.

                best_val_loss.append(val_loss)

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model.
    model_load(args.save)

    # Run on test data.
    test_loss = evaluate(test_data, args.epochs + 1, test_batch_size)
    print('=' * 89)
    print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
        test_loss, math.exp(test_loss), test_loss / math.log(2)))
    print('=' * 89)

    return np.array(valid_loss), test_loss
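
# Entry-point sketch (hypothetical wiring, not part of the original script): run() reads many
# fields from `args`, so a real driver would register one argparse flag per hyperparameter.
# Only a few representative flags are shown; the remaining ones (emsize, nhid, dropout*,
# bptt, batch_size, save, when, ...) would be added the same way.
#
#     import argparse
#
#     if __name__ == '__main__':
#         parser = argparse.ArgumentParser(description='RNN language model training')
#         parser.add_argument('--data', type=str, help='location of the data corpus')
#         parser.add_argument('--seed', type=int, default=1111)
#         parser.add_argument('--cuda', action='store_true')
#         # ... remaining hyperparameters referenced in run() ...
#         run(parser.parse_args())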