def run_model(model, training_data, testing_data, epochs=5, log_every=10):
    """Train ``model`` on ``training_data``, then evaluate it on ``testing_data``.

    Uses the names ``optimizer``, ``criterion`` and ``post_processing``, which
    must be defined outside this function (module level).

    Args:
        model: the network to train; it is called as ``model(data)``.
        training_data: iterable of ``(data, target)`` batches used for training.
        testing_data: iterable of ``(data, target)`` batches used for evaluation.
        epochs: number of passes over ``training_data`` (default 5, the value
            that used to be hard-coded).
        log_every: print a progress line every this many batches (default 10).

    Prints training progress, the predictions and targets for every test
    batch, the average test loss and the number of correct predictions.
    Returns None.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be at least 1")

    # This function ends with model.eval(); without this call a repeated run
    # (or a run after another evaluation) would train with dropout/batch-norm
    # layers still in evaluation mode.
    model.train()
    training_loss = []
    for epoch in range(epochs):
        train_loss = 0.0
        # Report in groups of `log_every` batches to keep the program responsive.
        # A trailing partial group at the end of an epoch is not recorded.
        for batch, (data, target) in enumerate(training_data):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if batch % log_every == log_every - 1:
                print("Epoch: {}, Total Batches: {}\nAverage Training Loss: {}".format(
                    epoch + 1, batch + 1, train_loss / log_every))
                # Stores the summed (not averaged) loss of the group.
                training_loss.append(train_loss)
                train_loss = 0.0
    print("Total training loss {}".format(training_loss))

    # Test the model using the test dataset.
    model.eval()
    test_loss = 0.0
    correct = 0
    attempted = 0
    for data, target in testing_data:
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, target)
        processed_output = post_processing(output)
        for prediction, actual in zip(processed_output, target):
            attempted += 1
            if prediction == actual:
                correct += 1
        print("Predictions: {}".format(processed_output))
        print("Target: {}".format(target))
        print("-----")
        test_loss += loss.item()

    if not attempted:
        # Avoid ZeroDivisionError when testing_data yields no samples.
        print("No test images were evaluated.")
        return

    # NOTE(review): test_loss sums one criterion value per batch but is divided
    # by the number of images. This is a per-image mean only if criterion uses
    # reduction="sum". Confirm against the criterion's definition.
    avg_loss = test_loss / attempted
    print("Average total loss is {:.6f}".format(avg_loss))
    print("{} correct predictions out of {} total images".format(correct, attempted))
def run_model_demo(model, training_data, image_path, image_transformer,
                   epochs=5, log_every=10):
    """Train ``model``, then interactively classify user-supplied images.

    Uses the names ``optimizer``, ``criterion``, ``post_processing``,
    ``UserData``, ``DataLoader``, ``ints_to_label`` and ``visualize_guess``,
    which must be defined outside this function.

    Args:
        model: the network to train; it is called as ``model(data)``.
        training_data: iterable of ``(data, target)`` batches used for training.
        image_path: passed through to ``UserData`` along with the entered name.
        image_transformer: passed through to ``UserData``.
        epochs: number of passes over ``training_data`` (default 5, the value
            that used to be hard-coded).
        log_every: print a progress line every this many batches (default 10).

    The prompt loop ends when the user enters ``E`` (case-insensitive,
    surrounding whitespace ignored) or when standard input is closed.
    Returns None.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be at least 1")

    # The model may have been left in eval mode by an earlier evaluation;
    # make sure dropout/batch-norm layers are in training mode while training.
    model.train()
    training_loss = []
    for epoch in range(epochs):
        train_loss = 0.0
        # Report in groups of `log_every` batches to keep the program responsive.
        # A trailing partial group at the end of an epoch is not recorded.
        for batch, (data, target) in enumerate(training_data):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if batch % log_every == log_every - 1:
                print("Epoch: {}, Total Batches: {}\nAverage Training Loss: {}".format(
                    epoch + 1, batch + 1, train_loss / log_every))
                # Stores the summed (not averaged) loss of the group.
                training_loss.append(train_loss)
                train_loss = 0.0
    print("Total training loss {}".format(training_loss))

    # Run the model demo.
    print("The model is trained and ready to make predictions!")
    model.eval()
    while True:
        try:
            image = input(
                "Enter the file name of the image to be evaluated or E to exit: ")
        except EOFError:
            # stdin was closed (e.g. piped input ran out): treat as an exit request.
            image = "E"
        image = image.strip()
        if image.upper() == "E":
            print("Thanks for participating!")
            break
        try:
            user_data = DataLoader(UserData(image_path, image, image_transformer))
            # DataLoader's default batch_size is 1, so each output holds a
            # single prediction and .item() below is valid.
            for data in user_data:
                with torch.no_grad():
                    output = model(data)
                processed_output = post_processing(output)
                guess = ints_to_label[processed_output.item()]
                print("The guess is: {}".format(guess))
                visualize_guess(image, guess)
        except OSError as err:
            # A mistyped or missing file should not end the interactive session.
            print("Could not load image {!r}: {}".format(image, err))