def generate(_run):
    """Sample text from a trained character-level model and print it.

    Args:
        _run: Run object whose ``config`` mapping provides ``model_path``,
            ``initial_seq``, ``sample_length`` and ``temperature``.

    Raises:
        KeyError: If ``initial_seq`` contains a character that is not part of
            the vocabulary stored in the checkpoint.
    """
    config = argparse.Namespace(**_run.config)

    # Load trained model and vocabulary mappings.
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed in via ``model_path``.
    checkpoint = torch.load(config.model_path, map_location='cpu')
    model = TextGenerationModel(**checkpoint['hparams'])
    model.load_state_dict(checkpoint['state_dict'])
    # A freshly constructed module is in training mode; switch to eval so that
    # dropout-style layers (if the model has any) do not perturb the samples.
    model.eval()

    ix_to_char = checkpoint['ix_to_char']
    char_to_ix = {char: ix for ix, char in ix_to_char.items()}

    # Fail early, naming the offending characters, instead of surfacing a bare
    # KeyError for a single character from inside the comprehension below.
    unknown_chars = sorted(set(config.initial_seq) - set(char_to_ix))
    if unknown_chars:
        raise KeyError(
            'initial_seq contains characters missing from the vocabulary: '
            '{!r}'.format(unknown_chars))

    # Prepare initial sequence as a column of character indices.
    x0 = torch.tensor([char_to_ix[char] for char in config.initial_seq],
                      dtype=torch.long).view(-1, 1)

    # Generate. No gradients are needed for sampling.
    with torch.no_grad():
        samples = model.sample(x0, config.sample_length, config.temperature)
    # reshape(-1) rather than squeeze(): squeeze() turns a single-element
    # result into a 0-d tensor, which cannot be iterated over.
    samples = samples.detach().cpu().reshape(-1)

    print('\n\nSample:')
    print('-------')
    text = ''.join(ix_to_char[ix.item()] for ix in samples)
    print(text)
    print('-------\n\n')
def train(config):
    """Train the character-level text generation model.

    Resumes from a fixed intermediate checkpoint, then runs 50 epochs over the
    text dataset. Progress is printed every ``config.print_every`` steps; every
    ``config.sample_every`` steps a sample is generated via ``predict`` and the
    model and optimizer state dicts are written to the working directory.

    Args:
        config: Namespace-like object providing ``txt_file``, ``seq_length``,
            ``batch_size``, ``lstm_num_hidden``, ``lstm_num_layers``,
            ``learning_rate``, ``learning_rate_step``, ``learning_rate_decay``,
            ``max_norm``, ``print_every``, ``sample_every`` and
            ``train_steps``.
    """
    # Initialize the device on which to run the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = TextDataset(filename=config.txt_file,
                          seq_length=config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    vocab_size = dataset.vocab_size
    char_to_ix = dataset._char_to_ix
    ix_to_char = dataset._ix_to_char

    # Initialize the model that we are going to use.
    model = TextGenerationModel(batch_size=config.batch_size,
                                seq_length=config.seq_length,
                                vocabulary_size=vocab_size,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)
    # Move the model BEFORE building/restoring the optimizer:
    # Optimizer.load_state_dict casts the restored state to the device the
    # parameters are on at that moment. Moving the model afterwards would
    # leave Adam's moment buffers on the CPU while parameters and gradients
    # live on the GPU, and the first optimizer.step() would fail.
    model = model.to(device)

    # Setup the loss, optimizer and learning-rate schedule.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = scheduler_lib.StepLR(optimizer=optimizer,
                                     step_size=config.learning_rate_step,
                                     gamma=config.learning_rate_decay)

    # NOTE(review): training always resumes from these hard-coded checkpoint
    # files and raises if they are missing; consider making this configurable.
    resume_model_path = 'grimm-results/intermediate-model-epoch-30-step-0.pth'
    resume_optim_path = 'grimm-results/intermediate-optim-epoch-30-step-0.pth'
    model.load_state_dict(torch.load(resume_model_path, map_location='cpu'))
    optimizer.load_state_dict(
        torch.load(resume_optim_path, map_location='cpu'))
    print("Loaded it!")

    num_epochs = 50
    for epoch in range(num_epochs):
        # Initialization of the state that's given to the forward pass;
        # reset every epoch.
        h, c = model.reset_lstm(config.batch_size)
        h = h.to(device)
        c = c.to(device)

        for step, (batch_inputs, batch_targets) in enumerate(data_loader):
            # Only for time measurement of step through network.
            t1 = time.time()

            # Re-enabled every step; presumably predict() below switches the
            # model to eval mode -- confirm against its implementation.
            model.train()
            optimizer.zero_grad()

            x = torch.stack(batch_inputs, dim=1).to(device)
            if x.size(0) != config.batch_size:
                # The recurrent state is sized for config.batch_size, so a
                # batch of any other size ends the epoch.
                print("We're breaking because something is wrong")
                print("Current batch is of size {}".format(x.size(0)))
                print("Supposed batch size is {}".format(config.batch_size))
                break

            y = torch.stack(batch_targets, dim=1).to(device)
            x = one_hot_encode(x, vocab_size)

            output, (h, c) = model(x=x, prev_state=(h, c))
            loss = criterion(output.transpose(1, 2), y)
            accuracy = calculate_accuracy(output, y)

            # Carry the state values into the next step without keeping the
            # autograd history of this step attached to them.
            h = h.detach()
            c = c.detach()

            loss.backward()
            # Clip gradients to keep the update norm bounded.
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            # The schedule advances once per batch, so
            # config.learning_rate_step is counted in training steps.
            scheduler.step()

            # Just for time measurement; guard against a zero interval on
            # clocks with coarse resolution.
            elapsed = time.time() - t1
            if elapsed > 0:
                examples_per_second = config.batch_size / elapsed
            else:
                examples_per_second = float('inf')

            if step % config.print_every == 0:
                # Print plain formatted numbers instead of tensor reprs.
                print(
                    f"Epoch {epoch} Train Step {step}/{config.train_steps}, "
                    f"Examples/Sec = {examples_per_second:.2f}, "
                    f"Accuracy = {float(accuracy):.2f}, "
                    f"Loss = {loss.item():.3f}"
                )

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model.
                first_char = 'I'  # Is randomized within the prediction, actually
                predict(device, model, first_char, vocab_size, ix_to_char,
                        char_to_ix)

                # Save intermediate model and optimizer state.
                path_model = 'intermediate-model-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                path_optimizer = 'intermediate-optim-epoch-{}-step-{}.pth'.format(
                    epoch, step)
                torch.save(model.state_dict(), path_model)
                torch.save(optimizer.state_dict(), path_optimizer)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')