import numpy as np

# Assumes the building blocks used below (DataLoader, Model, Linear, ReLU,
# Exponential, CrossEntropy, KLDivergenceStandardNormal, Reparameterization)
# are defined earlier and in scope.


def train_variational_autoencoder(
    learning_rate: float,
    epochs: int,
    batch_size: int,
    latent_variables: int = 10,
    print_every: int = 50,
) -> None:
    print(
        f"Training a variational autoencoder for {epochs} epochs "
        f"with batch size {batch_size}"
    )
    data_loader = DataLoader(batch_size)
    image_loss = CrossEntropy()
    divergence_loss = KLDivergenceStandardNormal()
    # The encoder predicts a mean and a variance for each latent variable;
    # the Exponential layer keeps the predicted variance strictly positive.
    encoder_mean = Model([Linear(784, 50), ReLU(), Linear(50, latent_variables)])
    encoder_variance = Model(
        [Linear(784, 50), ReLU(), Linear(50, latent_variables), Exponential()]
    )
    reparameterization = Reparameterization()
    decoder = Model([Linear(latent_variables, 50), ReLU(), Linear(50, 784)])
    for i in range(epochs):
        # One training loop
        training_data = data_loader.get_training_data()
        for j, batch in enumerate(training_data):
            input, _ = batch  # the autoencoder reconstructs its own input
            # Forward pass
            mean = encoder_mean(input)
            variance = encoder_variance(input)
            z = reparameterization(mean=mean, variance=variance)
            generated_samples = decoder(z)
            # Loss calculation
            divergence_loss_value = divergence_loss(mean, variance)
            generation_loss = image_loss(generated_samples, input)
            if j % print_every == 0:
                print(
                    f"Epoch {i+1}/{epochs}, "
                    f"training iteration {j+1}/{len(training_data)}"
                )
                print(
                    f"KL loss {np.round(divergence_loss_value, 2)}\t"
                    f"Generation loss {np.round(generation_loss, 2)}"
                )
            # Backward pass
            decoder_gradient = image_loss.gradient()
            decoder_gradient = decoder.backward(decoder_gradient)
            decoder_mean_gradient, decoder_variance_gradient = (
                reparameterization.backward(decoder_gradient)
            )
            encoder_mean_gradient, encoder_variance_gradient = (
                divergence_loss.gradient()
            )
            # Each encoder head receives gradients from both loss terms:
            # the reconstruction loss (through the decoder) and the KL term.
            encoder_mean.backward(decoder_mean_gradient + encoder_mean_gradient)
            encoder_variance.backward(
                decoder_variance_gradient + encoder_variance_gradient
            )
            # Gradient descent step on all three networks.
            decoder.update(learning_rate)
            encoder_mean.update(learning_rate)
            encoder_variance.update(learning_rate)
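# The trainer above relies on two VAE-specific pieces whose definitions do not
# appear in this section: the reparameterization trick and the closed-form KL
# divergence against a standard normal. A minimal sketch of both follows,
# assuming the callable-forward / backward convention the rest of the code
# uses; the internals here are illustrative, not the definitions used above.


class Reparameterization:
    """Samples z = mean + sqrt(variance) * epsilon, with epsilon ~ N(0, I)."""

    def __call__(self, mean: np.ndarray, variance: np.ndarray) -> np.ndarray:
        self.variance = variance
        self.epsilon = np.random.standard_normal(mean.shape)
        return mean + np.sqrt(variance) * self.epsilon

    def backward(self, gradient: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # dz/dmean = 1 and dz/dvariance = epsilon / (2 * sqrt(variance)),
        # so the incoming gradient is routed to both encoder heads.
        mean_gradient = gradient
        variance_gradient = gradient * self.epsilon / (2 * np.sqrt(self.variance))
        return mean_gradient, variance_gradient


class KLDivergenceStandardNormal:
    """KL(N(mean, variance) || N(0, I)), summed over the latent variables:
    0.5 * sum(variance + mean**2 - 1 - log(variance))."""

    def __call__(self, mean: np.ndarray, variance: np.ndarray) -> float:
        self.mean, self.variance = mean, variance
        return 0.5 * np.sum(variance + mean**2 - 1 - np.log(variance))

    def gradient(self) -> tuple[np.ndarray, np.ndarray]:
        # dKL/dmean = mean and dKL/dvariance = 0.5 * (1 - 1 / variance).
        return self.mean, 0.5 * (1 - 1 / self.variance)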
def train_classifier(
    learning_rate: float, epochs: int, batch_size: int, print_every: int = 50
) -> None:
    data_loader = DataLoader(batch_size)
    loss = CrossEntropy()
    model = Model([Linear(784, 50), ReLU(), Linear(50, 10)])
    for i in range(epochs):
        # One training loop
        training_data = data_loader.get_training_data()
        validation_data = data_loader.get_validation_data()
        for j, batch in enumerate(training_data):
            input, target = batch
            y = model(input)
            loss(y, target)
            gradient = loss.gradient()
            model.backward(gradient)
            model.update(learning_rate)
            if j % print_every == 0:
                print(
                    f"Epoch {i+1}/{epochs}, "
                    f"training iteration {j+1}/{len(training_data)}"
                )
        accuracy_values = []
        loss_values = []
        # One validation loop
        for j, batch in enumerate(validation_data):
            input, target = batch
            y = model(input)
            loss_value = loss(y, target)
            accuracy = calculate_accuracy(y, target)
            accuracy_values.append(accuracy)
            loss_values.append(loss_value)
        print(
            f"Epoch {i+1}: loss {np.round(np.average(loss_values), 2)}, "
            f"accuracy {np.round(np.average(accuracy_values), 2)}"
        )
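# calculate_accuracy is also not defined in this section. A plausible
# implementation, assuming y holds per-class scores and target is one-hot
# encoded (both are assumptions about the surrounding code):


def calculate_accuracy(y: np.ndarray, target: np.ndarray) -> float:
    # Fraction of samples whose highest-scoring class matches the label.
    return float(np.mean(np.argmax(y, axis=-1) == np.argmax(target, axis=-1)))


# Example invocation; the hyperparameter values are illustrative, not
# prescribed by this section:
if __name__ == "__main__":
    train_classifier(learning_rate=0.01, epochs=10, batch_size=32)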