def test_generator_training():
    """Smoke-test generator training against a pre-trained discriminator.

    Loads ``animals.txt`` and builds an untrained generator that is fed a
    fixed random input matrix. It then trains a discriminator to separate the
    real sequences from the generator's output and prints its accuracy. Next
    it trains the generator against that discriminator (with progress
    printing) and prints the discriminator's accuracy again on freshly
    generated sequences.
    """
    # hyper-parameters
    data_path = "animals.txt"
    g_hidden, d_hidden = 10, 3
    d_epochs, g_epochs = 20, 20
    learning_rate, decay = 1, 0.9
    batch = 100

    # real data: vocabulary and the encoded sequences
    vocab = dataloader.get_char_list(data_path)
    real = dataloader.load_data(data_path)
    n_examples = real.shape[0]
    n_steps = real.shape[1]

    # fixed random input reused for every generation call below; drawn before
    # the Generator is built so the RNG consumption order is unchanged
    noise = np.random.randn(n_examples, len(vocab))
    gen = Generator(g_hidden, vocab)
    fake = gen.generate_tensor(n_steps, n_examples, noise)

    # fit the discriminator on real vs. untrained-generator output
    disc = Discriminator(len(vocab), d_hidden)
    disc.train_RMS(real, fake, d_epochs, learning_rate, decay, batch)
    print("accuracy: ", disc.accuracy(real, fake))

    # train the generator against the fitted discriminator
    gen.train_RMS(noise, n_steps, disc, g_epochs, 1, learning_rate, decay,
                  batch, print_progress=True)

    # re-score the discriminator on output from the trained generator
    fake = gen.generate_tensor(n_steps, n_examples, noise)
    print("accuracy: ", disc.accuracy(real, fake))
def test_discriminator():
    """Smoke-test discriminator training on real vs. untrained-generator data.

    Loads ``animals.txt`` and generates the same number of sequences from an
    untrained generator. It trains a discriminator on the two sets (with
    progress printing), prints the raw discriminator output for the stacked
    real+generated batch, and finally prints the classification accuracy.
    """
    # hyper-parameters
    data_path = "animals.txt"
    g_hidden, d_hidden = 10, 11
    epochs = 20
    learning_rate, decay = 1, 0.9
    batch = 100

    # real data: vocabulary and the encoded sequences
    vocab = dataloader.get_char_list(data_path)
    real = dataloader.load_data(data_path)
    n_examples = real.shape[0]
    n_steps = real.shape[1]

    # fake data from an untrained generator (no explicit input matrix given)
    gen = Generator(g_hidden, vocab)
    fake = gen.generate_tensor(n_steps, n_examples)

    # fit the discriminator
    disc = Discriminator(len(vocab), d_hidden)
    disc.train_RMS(real, fake, epochs, learning_rate, decay, batch,
                   print_progress=True)

    # raw scores: real rows first, then generated rows (stacked on axis 0)
    both = np.concatenate([real, fake], axis=0)
    print(disc.discriminate(both))

    print("accuracy: ", disc.accuracy(real, fake))
class GAN:
    """Couples a character Generator with a Discriminator for adversarial training.

    Args:
        g_hidden_size: size of the generator's hidden layer.
        d_hidden_size: size of the discriminator's hidden layer.
        char_list: characters the generator can produce; its length is also
            passed to the Discriminator as its first constructor argument.
    """

    def __init__(self, g_hidden_size, d_hidden_size, char_list):
        self.char_list = char_list
        self.generator = Generator(g_hidden_size, char_list)
        self.discriminator = Discriminator(len(char_list), d_hidden_size)

    def train(self, X_actual, seq_len, n_epochs, g_epochs, d_epochs,
              g_initial_lr, d_initial_lr, g_multiplier, d_multiplier,
              g_batch_size, d_batch_size, print_progress=False,
              num_displayed=None):
        """Alternate discriminator and generator training for ``n_epochs`` rounds.

        Args:
            X_actual: real (non-generated) training data; its first dimension
                is the number of examples.
            seq_len: length of the sequences the generator produces.
            n_epochs: number of alternating rounds for the whole network.
            g_epochs: generator training epochs per round.
            d_epochs: discriminator training epochs per round.
            g_initial_lr, g_multiplier: generator RMSprop parameters.
            d_initial_lr, d_multiplier: discriminator RMSprop parameters.
            g_batch_size, d_batch_size: batch sizes for generator and
                discriminator training.
            print_progress: when True, each round prints the discriminator's
                accuracy before and after generator training, followed by
                generated text samples.
            num_displayed: number of generated samples to print when
                ``print_progress`` is True; ``None`` prints all of them.
        """
        gen = self.generator
        disc = self.discriminator
        total = X_actual.shape[0]

        # TODO: draw a fresh input matrix every round instead of reusing this one.
        noise = np.random.randn(total, gen.input_size)

        def report(label):
            # Regenerate from the shared input and score the discriminator on it.
            fresh = gen.generate_tensor(seq_len, total, noise)
            print(label, disc.accuracy(X_actual, fresh))

        for _ in range(n_epochs):
            # Discriminator step: real data vs. current generator output.
            fake = gen.generate_tensor(seq_len, total, noise)
            disc.train_RMS(X_actual, fake, d_epochs, d_initial_lr,
                           d_multiplier, d_batch_size)
            if print_progress:
                report("accuracy before generator training: ")

            # Generator step against the freshly updated discriminator.
            gen.train_RMS(noise, seq_len, disc, g_epochs, 1, g_initial_lr,
                          g_multiplier, g_batch_size)

            if not print_progress:
                continue
            report("accuracy after generator training: ")

            # Show (a prefix of) the generator's decoded output.
            samples = gen.generate(seq_len, total, noise)
            if num_displayed is not None:
                samples = samples[:num_displayed]
            for line in samples:
                print(line)