import os
import warnings

import cv2
import numpy as np

# ANN, load_mnist and show_samples are helpers defined elsewhere in this
# repository.


class GAN(object):
    def __init__(self,
                 D_hidden_dim,
                 G_hidden_dim,
                 z_dim,
                 hyperparams={},
                 dataset='mnist',
                 image_dim=None):
        self.dataset = dataset

        # Check for dataset=None first, so we never call .lower() on None
        if dataset is None:
            if image_dim is None:
                raise RuntimeError("You must either define a recognised "
                                   "dataset or define an input image "
                                   "dimension")
        elif dataset.lower() == 'mnist':
            image_dim = 28 * 28
            self.digit = hyperparams.get("digit", 2)
        elif dataset.lower() == 'celeba_bw':
            print("This basic GAN version probably will not converge. "
                  "See TTitcombe/GANmodels for more powerful versions "
                  "(in development)")
            image_dim = 178 * 218
        else:
            raise NotImplementedError("The dataset you have selected "
                                      "is not recognised")

        self.epochs = hyperparams.get("epochs", 100)
        self.batchSize = hyperparams.get("batchSize", 64)
        self.lr = hyperparams.get("lr", 0.001)
        self.decay = hyperparams.get("decay", 1.)
        self.epsilon = hyperparams.get("epsilon", 1e-7)  # avoid log(0)

        self.z_dim = z_dim

        self.D = ANN(image_dim, D_hidden_dim, 1, self.lr, False)
        self.G = ANN(z_dim, G_hidden_dim, image_dim, self.lr, True)

    def train(self, X_train=None):
        if self.dataset and self.dataset.lower() == 'mnist':
            X_train, N_train = load_mnist(self.digit)
            np.random.shuffle(X_train)
        elif self.dataset and self.dataset.lower() == 'celeba_bw':
            # X_train is a path to the image directory
            _, _, filenames = next(os.walk(X_train))
            N_train = len(filenames)
        elif X_train is None:
            raise RuntimeError("X training data must be provided")
        else:
            N_train = X_train.shape[0]
            np.random.shuffle(X_train)

        N_batch = N_train // self.batchSize

        for epoch in range(self.epochs):
            g_loss_tracker = [0.]
            g_loss_differences = []
            d_loss_tracker = []

            for step in range(N_batch):
                if self.dataset and self.dataset.lower() == 'celeba_bw':
                    file = filenames[step]
                    path = X_train + file
                    X_batch = cv2.imread(path)  # cv2 loads images as BGR
                    X_batch = cv2.cvtColor(X_batch, cv2.COLOR_BGR2GRAY)
                    # Flatten to (1, image_dim) and scale to [0, 1]
                    X_batch = X_batch.reshape(1, -1) / 255.
                else:
                    X_batch = X_train[step * self.batchSize:
                                      (1 + step) * self.batchSize]
                    if X_batch.shape[0] != self.batchSize:
                        break

                # Generate z from a clipped normal distribution
                z = np.random.normal(loc=0.0,
                                     scale=0.5,
                                     size=(self.batchSize, self.z_dim))
                z[z < -1] = -1.
                z[z > 1] = 1.

                # Feedforward
                g_logits, fake_img = self.G._feedforward(z)
                d_real_logits, d_real_output = self.D._feedforward(X_batch)
                d_fake_logits, d_fake_output = self.D._feedforward(
                    fake_img, True)

                # D loss: -log D(x) - log(1 - D(G(z)))
                d_loss = -np.log(d_real_output + self.epsilon) - np.log(
                    1 - d_fake_output + self.epsilon)

                # Track D loss over a rolling window: failure if it reaches
                # zero; varying wildly is probably bad too.
                assert np.mean(
                    d_loss) > 1e-8, "D loss has gone to zero - Failure case"
                d_loss_tracker.append(d_loss)
                if step > 9:
                    d_loss_tracker.pop(0)
                    # "varying wildly" is loosely defined; flag std > 5
                    if np.std(d_loss_tracker) > 5:
                        warnings.warn("D loss is varying sharply",
                                      UserWarning)

                # G loss: -log D(G(z))
                g_loss = -np.log(d_fake_output + self.epsilon)

                # Track G loss: if it decreases steadily, G may be fooling D
                # with garbage
                g_loss_differences.append(
                    np.mean(g_loss - g_loss_tracker[-1]))
                if step > 8:
                    g_loss_tracker.pop(0)
                    if np.std(g_loss_differences) < 1e-2:
                        warnings.warn(
                            "G loss is decreasing steadily. Check G output.",
                            UserWarning)
                g_loss_tracker.append(np.mean(g_loss))

                # Update with decayed learning rate
                self.G.setLR(self.lr)
                self.D.setLR(self.lr)

                # Backprop: G needs D's stored fake-pass activations so that
                # gradients can flow back through the discriminator
                d_archs = self.D.archs
                d_n_layers = self.D.N_layers
                d_lin_store = self.D.fake_lin_store
                d_act_store = self.D.fake_act_store

                self.D.backprop()
                self.G.backprop(d_n_layers, d_act_store, d_lin_store,
                                d_archs)

                # Show samples
                samples = self.sample()
                full_image = show_samples(samples, 25, self.dataset)
                cv2.imshow('Samples', full_image)
                cv2.waitKey(1)

                #fid = FID(samples, X_train_reshaped)

                print(
                    "Epoch: %d; Step: %d; G Loss: %.4f; D Loss: %.4f; "
                    "Real ac: %.4f; Fake ac: %.4f" %
                    (epoch, step, np.mean(g_loss), np.mean(d_loss),
                     np.mean(d_real_output), np.mean(d_fake_output)))

            self.lr = self.lr * self.decay

    def sample(self):
        z = np.random.normal(loc=0.0,
                             scale=0.5,
                             size=(self.batchSize, self.z_dim))
        z[z < -1] = -1.
        z[z > 1] = 1.
        _, fake_img = self.G._feedforward(z)
        return fake_img