def sample(model: CharRNN, char2int: dict, prime='The', num_chars=1000, top_k=5):
    """Generate text from a trained network.

    Primes the hidden state with *prime*, then autoregressively predicts
    *num_chars* further characters, sampling from the *top_k* most likely
    at each step.

    :param model: trained CharRNN (moved to its device already).
    :param char2int: mapping from character to integer id.
    :param prime: seed text to warm up the hidden state; assumed non-empty
        and fully covered by ``char2int`` — TODO confirm with callers.
    :param num_chars: number of characters to generate after the prime.
    :param top_k: sample among the k most probable next characters.
    :return: the prime text followed by the generated characters.
    """
    device = next(model.parameters()).device.type
    int2char = {ii: ch for ch, ii in char2int.items()}

    # Evaluation mode: the model uses dropout, which must be disabled here.
    model.eval()

    # Warm up the hidden state on the prime characters. NOTE: we must NOT
    # append predictions while iterating `chars` — the original code did
    # `for ch in chars: ... chars.append(char)`, which grows the list it is
    # iterating and therefore never terminates. Only the prediction made
    # after the final prime character is kept.
    chars = [char2int[ch] for ch in prime]
    h = model.init_hidden(1, device)
    for ch in chars:
        char, h = predict(model, ch, h, top_k, device)
    chars.append(char)

    # Autoregressive generation: feed back the previously sampled character.
    for _ in range(num_chars):
        char, h = predict(model, chars[-1], h, top_k, device)
        chars.append(char)

    return ''.join(int2char[c] for c in chars)
def train(args, model: CharRNN, step, epoch, corpus, char_to_id, criterion, model_file):
    """Train *model* on *corpus*, checkpointing after every epoch.

    Resumable: starts at *epoch*/*step* and runs through ``args.n_epochs``.
    Ctrl+C saves a checkpoint at the current epoch and returns cleanly.
    Progress events are appended to ``<args.root>/train.log``.

    :param args: namespace with lr, window_size, batch_size, epoch_batches,
        n_epochs, root and valid_corpus attributes.
    :param model: the CharRNN being trained (mutated in place).
    :param step: global step counter to resume from.
    :param epoch: first epoch number to run.
    :param corpus: training text.
    :param char_to_id: mapping from character to integer id.
    :param criterion: loss function.
    :param model_file: path the checkpoint is written to.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    batch_chars = args.window_size * args.batch_size

    def save(ep):
        # Checkpoint carries epoch/step so training can be resumed.
        torch.save(
            {
                'state': model.state_dict(),
                'epoch': ep,
                'step': step,
            },
            str(model_file))

    # `with` guarantees the log file is closed on every exit path
    # (the original opened it and never closed it, leaking the handle
    # on the KeyboardInterrupt `return` in particular).
    with Path(args.root).joinpath('train.log').open('at', encoding='utf8') as log:
        for epoch in range(epoch, args.n_epochs + 1):
            try:
                losses = []
                n_iter = args.epoch_batches or (len(corpus) // batch_chars)
                # Clamp to >= 1: with n_iter <= 1 the original produced 0
                # (or negative), risking `i % 0` and odd slice windows.
                report_each = max(1, min(10, n_iter - 1))
                tr = tqdm.tqdm(total=n_iter * batch_chars)
                tr.set_description('Epoch {}'.format(epoch))
                model.train()
                for i in range(n_iter):
                    inputs, targets = random_batch(
                        corpus,
                        batch_size=args.batch_size,
                        window_size=args.window_size,
                        char_to_id=char_to_id,
                    )
                    loss = train_model(model, criterion, optimizer, inputs, targets)
                    step += 1
                    losses.append(loss)
                    tr.update(batch_chars)
                    # Smooth the displayed loss over the last window.
                    mean_loss = np.mean(losses[-report_each:])
                    tr.set_postfix(loss=mean_loss)
                    if i and i % report_each == 0:
                        write_event(log, step, loss=mean_loss)
                tr.close()
                # ep=epoch+1 so a resume continues with the next epoch.
                save(ep=epoch + 1)
            except KeyboardInterrupt:
                print('\nGot Ctrl+C, saving checkpoint...')
                save(ep=epoch)
                print('done.')
                return
            if args.valid_corpus:
                valid_result = validate(args, model, criterion, char_to_id)
                write_event(log, step, **valid_result)
        print('Done training for {} epochs'.format(args.n_epochs))
def count_parameters(model: CharRNN):
    """Return the total number of trainable parameters in *model*.

    Only parameters with ``requires_grad`` set (i.e. those the optimizer
    will update) are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total