def generate(_run):
    config = argparse.Namespace(**_run.config)

    # Load the trained model and vocabulary mappings
    checkpoint = torch.load(config.model_path, map_location='cpu')
    model = TextGenerationModel(**checkpoint['hparams'])
    model.load_state_dict(checkpoint['state_dict'])
    ix_to_char = checkpoint['ix_to_char']
    char_to_ix = {v: k for k, v in ix_to_char.items()}

    # Prepare the initial sequence
    x0 = torch.tensor([char_to_ix[char] for char in config.initial_seq],
                      dtype=torch.long).view(-1, 1)

    # Generate
    samples = model.sample(x0, config.sample_length,
                           config.temperature).detach().cpu().squeeze()

    print('\n\nSample:')
    print('-------')
    text = ''.join(ix_to_char[ix.item()] for ix in samples)
    print(text)
    print('-------\n\n')
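
# Illustrative sketch (not part of the original module): model.sample is
# assumed to draw each next character from a temperature-scaled softmax over
# the model's output logits, roughly as in the hypothetical helper below.
def _sample_next_char(logits, temperature=1.0):
    # logits: (vocab_size,) unnormalized scores for the next character.
    # Lower temperature sharpens the distribution (more greedy output),
    # higher temperature flattens it (more diverse output).
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)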
def train(_run):
    config = argparse.Namespace(**_run.config)

    # Initialize the device
    device = torch.device(config.device)

    # Initialize the dataset and data loader
    dataset = TextDataset(config.txt_file, config.seq_length)
    total_samples = int(config.train_steps * config.batch_size)
    sampler = RandomSampler(dataset, replacement=True, num_samples=total_samples)
    data_sampler = BatchSampler(sampler, config.batch_size, drop_last=False)
    data_loader = DataLoader(dataset, num_workers=1, batch_sampler=data_sampler)

    # Initialize the model that we are going to use
    model = TextGenerationModel(dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers).to(device)

    # Set up the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of one step through the network
        t1 = time.time()

        # Prepare data
        batch_inputs = torch.stack(batch_inputs).to(device)
        batch_targets = torch.stack(batch_targets).t().to(device)

        # Forward, backward, optimize
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % config.print_every == 0:
            accuracy = eval_accuracy(logits, batch_targets)
            loss = batch_loss.item()

            log_str = ("[{}] Train Step {:04d}/{:04d}, "
                       "Batch Size = {}, Examples/Sec = {:.2f}, "
                       "Accuracy = {:.2f}, Loss = {:.3f}")
            print(log_str.format(datetime.now().strftime("%Y-%m-%d %H:%M"),
                                 step, config.train_steps, config.batch_size,
                                 examples_per_second, accuracy, loss))

            _run.log_scalar('loss', loss, step)
            _run.log_scalar('acc', accuracy, step)

        if step % config.sample_every == 0:
            # Generate some sentences by sampling from the model
            print('-' * (config.sample_length + 1))
            x0 = torch.randint(low=0, high=dataset.vocab_size, size=(1, 5))
            samples = model.sample(x0, config.sample_length).detach().cpu()
            samples = samples.numpy()
            for sample in samples:
                print(dataset.convert_to_string(sample))
            print('-' * (config.sample_length + 1))

        if step == config.train_steps:
            break

    print('Done training.')

    ckpt_path = os.path.join(SAVE_PATH, str(config.timestamp) + '.pt')
    torch.save(
        {
            'state_dict': model.state_dict(),
            'hparams': model.hparams,
            'ix_to_char': dataset.ix_to_char,
        }, ckpt_path)
    print('Saved checkpoint to {}'.format(ckpt_path))
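
# Illustrative sketch (not part of the original module): the eval_accuracy
# helper called in train() is defined elsewhere. A minimal version could look
# like the hypothetical function below, assuming logits of shape
# (batch, vocab_size, seq_len) -- the layout nn.CrossEntropyLoss expects --
# and integer character targets of shape (batch, seq_len).
def _eval_accuracy_example(logits, targets):
    # Pick the highest-scoring character at every position and compare it
    # against the ground-truth character.
    predictions = logits.argmax(dim=1)  # (batch, seq_len)
    return (predictions == targets).float().mean().item()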