import pickle
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as scheduler_lib
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# TextDataset and TextGenerationModel come from the accompanying assignment code
# and are assumed to be importable here.


def train(config):
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Initialize the device on which to run the model
    device = torch.device(config.device)
    writer = SummaryWriter()

    seq_length = config.seq_length
    batch_size = config.batch_size
    lstm_num_hidden = config.lstm_num_hidden
    lstm_num_layers = config.lstm_num_layers
    dropout_keep_prob = config.dropout_keep_prob

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, seq_length)
    data_loader = DataLoader(dataset, batch_size, num_workers=1)
    vocab_size = dataset.vocab_size

    # Initialize the model that we are going to use
    model = TextGenerationModel(batch_size, seq_length, vocab_size,
                                lstm_num_hidden, lstm_num_layers,
                                dropout_keep_prob, device)
    model.to(device)

    # Set up the loss, optimizer and learning-rate schedule
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, config.learning_rate_step,
                                             config.learning_rate_decay)

    for step, (batch_inputs, batch_targets) in enumerate(data_loader):

        # Only for time measurement of step through network
        t1 = time.time()

        # One-hot representation of the input or an embedding => decided for the embedding
        # batch_inputs = F.one_hot(batch_inputs, vocab_size).type(torch.FloatTensor).to(device)
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        train_output, _ = model(batch_inputs)

        loss = criterion(train_output, batch_targets)
        accuracy = torch.sum(
            torch.eq(torch.argmax(train_output, dim=1),
                     batch_targets)).item() / (batch_targets.size(0) *
                                               batch_targets.size(1))

        writer.add_scalar('Loss/train', loss.item(), step)
        writer.add_scalar('Accuracy/train', accuracy, step)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # Just for time measurement
        t2 = time.time()
        examples_per_second = config.batch_size / float(t2 - t1)

        if step % config.print_every == 0:
            print(
                "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                "Accuracy = {:.2f}, Loss = {:.3f}".format(
                    datetime.now().strftime("%Y-%m-%d %H:%M"), step,
                    int(config.train_steps), config.batch_size,
                    examples_per_second, accuracy, loss))

        if step % config.sample_every == 0:
            # Generate some sentences by sampling from the model
            sample_from_model(config, step, model, dataset)

        if step == config.train_steps:
            # If you receive a PyTorch data-loader error, check this bug report:
            # https://github.com/pytorch/pytorch/pull/9655
            break

    print('Done training.')
    torch.save(model, "trained_model_part2.pth")
    writer.close()
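# `sample_from_model` is called in the sampling branch above but is not defined
# in this file. The function below is a minimal sketch of what such a helper
# could look like, assuming the model returns (logits, state) with logits of
# shape (batch, vocab_size, seq_length) as implied by the training loop above;
# the actual assignment code may differ.
def sample_from_model(config, step, model, dataset):
    model.eval()
    with torch.no_grad():
        device = next(model.parameters()).device
        # Start from a single random character id, shape (batch=1, seq=1)
        current = torch.randint(dataset.vocab_size, (1, 1), device=device)
        generated = [current.item()]
        for _ in range(config.seq_length - 1):
            logits, _ = model(current)                      # (1, vocab_size, seq)
            next_char = logits[0, :, -1].argmax().view(1, 1)
            generated.append(next_char.item())
            # Feed the growing sequence back into the model (greedy decoding)
            current = torch.cat([current, next_char], dim=1)
    model.train()
    print("Step {}: {}".format(step, dataset.convert_to_string(generated)))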
# An alternative implementation of train(): one-hot inputs, explicit handling of
# the LSTM state, checkpoint loading/saving, and epoch-based training.
def train(config):
    # Initialize the device on which to run the model
    # device = torch.device(config.device)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = TextDataset(filename=config.txt_file, seq_length=config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)

    VOCAB_SIZE = dataset.vocab_size
    CHAR2IDX = dataset._char_to_ix
    IDX2CHAR = dataset._ix_to_char

    # Initialize the model that we are going to use
    model = TextGenerationModel(batch_size=config.batch_size,
                                seq_length=config.seq_length,
                                vocabulary_size=VOCAB_SIZE,
                                lstm_num_hidden=config.lstm_num_hidden,
                                lstm_num_layers=config.lstm_num_layers,
                                device=device)

    # Set up the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = scheduler_lib.StepLR(optimizer=optimizer,
                                     step_size=config.learning_rate_step,
                                     gamma=config.learning_rate_decay)

    # Resume from an intermediate checkpoint (disable this block to train from scratch)
    if True:
        model.load_state_dict(
            torch.load('grimm-results/intermediate-model-epoch-30-step-0.pth',
                       map_location='cpu'))
        optimizer.load_state_dict(
            torch.load("grimm-results/intermediate-optim-epoch-30-step-0.pth",
                       map_location='cpu'))
        print("Loaded it!")

    model = model.to(device)

    EPOCHS = 50
    for epoch in range(EPOCHS):
        # Initialize the state that is passed to the forward pass; reset every epoch
        h, c = model.reset_lstm(config.batch_size)
        h = h.to(device)
        c = c.to(device)

        for step, (batch_inputs, batch_targets) in enumerate(data_loader):

            # Only for time measurement of step through network
            t1 = time.time()

            model.train()
            optimizer.zero_grad()

            x = torch.stack(batch_inputs, dim=1).to(device)
            if x.size(0) != config.batch_size:
                print("We're breaking because something is wrong")
                print("Current batch is of size {}".format(x.size(0)))
                print("Supposed batch size is {}".format(config.batch_size))
                break

            y = torch.stack(batch_targets, dim=1).to(device)
            x = one_hot_encode(x, VOCAB_SIZE)

            output, (h, c) = model(x=x, prev_state=(h, c))

            loss = criterion(output.transpose(1, 2), y)
            accuracy = calculate_accuracy(output, y)

            # Detach the state so gradients do not flow across batches
            h = h.detach()
            c = c.detach()

            loss.backward()
            # Gradient clipping to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()
            scheduler.step()

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if step % config.print_every == 0:
                print("[{}] Epoch {} Train Step {}/{}, Examples/Sec = {:.2f}, "
                      "Accuracy = {:.2f}, Loss = {:.3f}".format(
                          datetime.now().strftime("%Y-%m-%d %H:%M"), epoch, step,
                          config.train_steps, examples_per_second, accuracy, loss))

            if step % config.sample_every == 0:
                # Generate some sentences by sampling from the model.
                # The first character is actually randomized inside the prediction.
                FIRST_CHAR = 'I'
                predict(device, model, FIRST_CHAR, VOCAB_SIZE, IDX2CHAR, CHAR2IDX)

                path_model = 'intermediate-model-epoch-{}-step-{}.pth'.format(epoch, step)
                path_optimizer = 'intermediate-optim-epoch-{}-step-{}.pth'.format(epoch, step)
                torch.save(model.state_dict(), path_model)
                torch.save(optimizer.state_dict(), path_optimizer)

            if step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                break

    print('Done training.')
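# `one_hot_encode` and `calculate_accuracy` are used in the loop above but are not
# defined in this file. Below are minimal sketches under the assumption that the
# model output has shape (batch, seq_length, vocab_size), as implied by the
# `output.transpose(1, 2)` passed to CrossEntropyLoss above.
def one_hot_encode(x, vocab_size):
    # x: (batch, seq_length) integer character ids -> (batch, seq_length, vocab_size) floats
    return torch.nn.functional.one_hot(x, num_classes=vocab_size).float()


def calculate_accuracy(output, targets):
    # Fraction of positions where the highest-scoring character matches the target
    predictions = output.argmax(dim=-1)
    return (predictions == targets).float().mean().item()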
# A third implementation of train(): character embeddings fed to the LSTM in one
# go, RMSprop, sentence sampling to a file, and matplotlib plots of the training curves.
def train(config):
    # Initialize the device on which to run the model
    # device = torch.device(config.device)

    # Initialize the dataset and data loader (note the +1)
    dataset = TextDataset(config.txt_file, config.seq_length)
    data_loader = DataLoader(dataset, config.batch_size, num_workers=1)
    # print(dataset._char_to_ix)
    # The vocabulary order changes between runs, but with the seeds set earlier the
    # batches contain the same sentence examples.

    # Initialize the model that we are going to use
    model = TextGenerationModel(config.batch_size, config.seq_length,
                                dataset.vocab_size, config.lstm_num_hidden,
                                config.lstm_num_layers, config.device)
    device = model.device
    model = model.to(device)

    # Set up the loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

    print("Len dataset:", len(dataset))
    print("Number of steps for the dataset:", len(dataset) / config.batch_size)

    current_step = 0
    not_max = True

    list_train_acc = []
    list_train_loss = []
    acc_average = []
    loss_average = []

    file = open("sentences.txt", 'w', encoding='utf-8')
    '''
    file_greedy = open("sentences_greedy.txt", 'w', encoding='utf-8')
    file_tmp_05 = open("sentences_tmp_05.txt", 'w', encoding='utf-8')
    file_tmp_1 = open("sentences_tmp_1.txt", 'w', encoding='utf-8')
    file_tmp_2 = open("sentences_tmp_2.txt", 'w', encoding='utf-8')
    '''

    while not_max:
        for (batch_inputs, batch_targets) in data_loader:

            # Only for time measurement of step through network
            t1 = time.time()

            # The inputs are lists of character indices (the dataset mapping) that are
            # turned into embeddings before being given to the LSTM all at once.
            # The embedding has shape (dataset.vocab_size, config.lstm_num_hidden).
            embed = model.embed

            all_embed = []
            # sentence = []
            for batch_letter in batch_inputs:
                batch_letter_to = batch_letter.to(device)
                embedding = embed(batch_letter_to)
                all_embed.append(embedding)
                # sentence.append(batch_letter_to[0].item())
            all_embed = torch.stack(all_embed)

            # Print the first example sentence of the batch along with its target
            # print(dataset.convert_to_string(sentence))
            # sentence = []
            # for batch_letter in batch_targets:
            #     sentence.append(batch_letter[0].item())
            # print(dataset.convert_to_string(sentence))

            all_embed = all_embed.to(device)
            # outputs: (seq_length=30, batch_size=64, vocab_size);
            # vocab_size is 87 for the fairy tales corpus
            outputs = model(all_embed)

            # For the loss, the predictions must be (batch_size, vocab_size, seq_length)
            # and the targets (batch_size, seq_length)
            batch_first_output = outputs.transpose(0, 1).transpose(1, 2)
            batch_targets = torch.stack(batch_targets).to(device)
            loss = criterion(batch_first_output, torch.t(batch_targets))

            # Backpropagate
            model.zero_grad()
            loss.backward()
            loss = loss.item()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=config.max_norm)
            optimizer.step()

            # Accuracy
            number_predictions = torch.argmax(outputs, dim=2)
            result = number_predictions == batch_targets
            accuracy = result.sum().item() / (batch_targets.shape[0] *
                                              batch_targets.shape[1])

            '''
            # Generate sentences for all temperature settings on every step
            sentence_id = model.generate_sentence(config.gsen_length, -1)
            sentence = dataset.convert_to_string(sentence_id)
            # print(sentence)
            file_greedy.write((str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 0.5)
            sentence = dataset.convert_to_string(sentence_id)
            # print(sentence)
            file_tmp_05.write((str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 1)
            sentence = dataset.convert_to_string(sentence_id)
            # print(sentence)
            file_tmp_1.write((str(current_step) + ": " + sentence + "\n"))

            sentence_id = model.generate_sentence(config.gsen_length, 2)
            sentence = dataset.convert_to_string(sentence_id)
            # print(sentence)
            file_tmp_2.write((str(current_step) + ": " + sentence + "\n"))
            '''

            if config.measure_type == 2:
                acc_average.append(accuracy)
                loss_average.append(loss)

            # Just for time measurement
            t2 = time.time()
            examples_per_second = config.batch_size / float(t2 - t1)

            if current_step % config.print_every == 0:
                # Average accuracy and loss over the last print_every steps (5 by default)
                if config.measure_type == 2:
                    accuracy = sum(acc_average) / config.print_every
                    loss = sum(loss_average) / config.print_every
                    acc_average = []
                    loss_average = []

                # Either the accuracy and loss at this interval, or the average over
                # the interval as computed above
                list_train_acc.append(accuracy)
                list_train_loss.append(loss)

                print(
                    "[{}] Train Step {:04d}/{:04d}, Batch Size = {}, Examples/Sec = {:.2f}, "
                    "Accuracy = {:.2f}, Loss = {:.3f}".format(
                        datetime.now().strftime("%Y-%m-%d %H:%M"), current_step,
                        int(config.train_steps), config.batch_size,
                        examples_per_second, accuracy, loss))
            elif config.measure_type == 0:
                # Track accuracy and loss for every step
                list_train_acc.append(accuracy)
                list_train_loss.append(loss)

            if current_step % config.sample_every == 0:
                # Generate a sentence
                sentence_id = model.generate_sentence(config.gsen_length,
                                                      config.temperature)
                sentence = dataset.convert_to_string(sentence_id)
                print(sentence)
                file.write((str(current_step) + ": " + sentence + "\n"))

            if current_step == config.train_steps:
                # If you receive a PyTorch data-loader error, check this bug report:
                # https://github.com/pytorch/pytorch/pull/9655
                not_max = False
                break

            current_step += 1

    # Close the file and make sure the sentences and measures are saved
    file.close()
    pickle.dump((list_train_acc, list_train_loss), open("loss_and_train.p", "wb"))

    # Plot the training curves
    print(len(list_train_acc))
    if config.measure_type == 0:
        # Accuracy was tracked at every step
        eval_steps = list(range(config.train_steps + 1))
    else:
        eval_steps = list(
            range(0, config.train_steps + config.print_every, config.print_every))

    if config.measure_type == 2:
        plt.plot(eval_steps[:-1], list_train_acc[1:], label="Train accuracy")
    else:
        plt.plot(eval_steps, list_train_acc, label="Train accuracy")
    plt.xlabel("Step")
    plt.ylabel("Accuracy")
    plt.title("Training accuracy LSTM", fontsize=18, fontweight="bold")
    plt.legend()
    # plt.savefig('accuracies.png', bbox_inches='tight')
    plt.show()

    if config.measure_type == 2:
        plt.plot(eval_steps[:-1], list_train_loss[1:], label="Train loss")
    else:
        plt.plot(eval_steps, list_train_loss, label="Train loss")
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training loss LSTM", fontsize=18, fontweight="bold")
    plt.legend()
    # plt.savefig('loss.png', bbox_inches='tight')
    plt.show()

    print('Done training.')
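# `model.generate_sentence(length, temperature)` is used above but implemented
# inside TextGenerationModel, which is not part of this file. The helper below
# sketches the sampling logic it presumably performs, assuming a non-positive
# temperature (e.g. -1) means greedy decoding and that the model consumes
# embeddings of shape (seq, batch, hidden) as in the training loop above. It is
# an illustration, not the assignment's actual method.
def generate_sentence_ids(model, dataset, length, temperature, device):
    model.eval()
    with torch.no_grad():
        # Start from a random character id
        ids = [torch.randint(dataset.vocab_size, (1,)).item()]
        for _ in range(length - 1):
            inputs = torch.tensor(ids, device=device).unsqueeze(1)  # (seq, batch=1)
            outputs = model(model.embed(inputs))                    # (seq, 1, vocab_size)
            logits = outputs[-1, 0]                                 # scores for the next character
            if temperature <= 0:
                next_id = int(logits.argmax())                      # greedy decoding
            else:
                # Temperature-scaled softmax sampling: higher temperature -> more random
                probs = torch.softmax(logits / temperature, dim=0)
                next_id = int(torch.multinomial(probs, num_samples=1))
            ids.append(next_id)
    model.train()
    return ids

# Example use, mirroring the sampling branch in the training loop:
# sentence = dataset.convert_to_string(
#     generate_sentence_ids(model, dataset, config.gsen_length, config.temperature, device))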