def runTest(self):
    encoder = vae.VariationalAutoEncoder(
        self.latent_dim,
        preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))
    decoding_layers = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
    ])
    inputs = tf.keras.layers.Input(shape=self.input_shape)
    z, kl_loss = encoder.sampling_forward(inputs)
    outputs = decoding_layers(z)
    loss = tf.reduce_mean(
        tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)
    model = tf.keras.Model(inputs, outputs, name="vae")
    model.add_loss(loss)
    model.compile(optimizer='adam')
    model.summary()
    hist = model.fit(
        self.x_train,
        epochs=self.epochs,
        batch_size=self.batch_size,
        validation_data=(self.x_test, None))
    last_val_loss = hist.history['val_loss'][-1]
    print("loss: ", last_val_loss)
    self.assertTrue(37.5 < last_val_loss <= 39.0)
    if INTERACTIVE_MODE:
        self.show_encoded_images(model)
        self.show_sampled_images(lambda eps: decoding_layers(eps))
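
# For reference, a minimal sketch of the reparameterization step that
# sampling_forward is assumed to perform above (standard VAE math with a
# N(0, I) prior; illustrative only, not the actual
# vae.VariationalAutoEncoder internals):
def _sampling_forward_sketch(z_mean, z_log_var):
    """Returns a latent sample z and the analytic KL(q(z|x) || N(0, I))."""
    eps = tf.random.normal(tf.shape(z_mean))
    z = z_mean + eps * tf.exp(0.5 * z_log_var)
    kl = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + tf.exp(z_log_var) - z_log_var - 1.0, axis=-1)
    return z, kl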
def test_vae(self):
    """Test for one dimensional Gaussian."""
    encoder = vae.VariationalAutoEncoder(
        self._latent_dim, input_tensor_spec=self._input_spec)
    decoding_layers = FC(self._latent_dim, 1)

    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoding_layers.parameters()),
        lr=0.1)

    x_train = self._input_spec.randn(outer_dims=(10000, ))
    x_test = self._input_spec.randn(outer_dims=(10, ))

    for _ in range(self._epochs):
        x_train = x_train[torch.randperm(x_train.shape[0])]
        for i in range(0, x_train.shape[0], self._batch_size):
            optimizer.zero_grad()
            batch = x_train[i:i + self._batch_size]
            alg_step = encoder.train_step(batch)
            outputs = decoding_layers(alg_step.output)
            loss = torch.mean(100 * self._loss_f(batch - outputs)
                              + alg_step.info.loss)
            loss.backward()
            optimizer.step()

    y_test = decoding_layers(encoder.train_step(x_test).output)
    reconstruction_loss = float(torch.mean(self._loss_f(x_test - y_test)))
    print("reconstruction_loss:", reconstruction_loss)
    self.assertLess(reconstruction_loss, 0.05)
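
# A sketch of the per-batch objective optimized above, assuming
# alg_step.info.loss is the analytic KL of a diagonal Gaussian posterior
# against N(0, I) (names are illustrative, not encoder.train_step internals):
def _vae_objective_sketch(batch, outputs, z_mean, z_log_var):
    recon = 100 * torch.mean((batch - outputs)**2)  # weighted reconstruction
    kl = torch.mean(0.5 * torch.sum(
        z_mean**2 + z_log_var.exp() - z_log_var - 1.0, dim=-1))
    return recon + kl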
def test_gaussian(self):
    """Test for one dimensional Gaussian."""
    input_shape = (1, )
    epochs = 20
    batch_size = 100
    latent_dim = 1
    loss_f = tf.square
    encoder = vae.VariationalAutoEncoder(latent_dim)
    decoding_layers = tf.keras.layers.Dense(1)

    inputs = tf.keras.layers.Input(shape=input_shape)
    z, kl_loss = encoder.sampling_forward(inputs)
    outputs = decoding_layers(z)
    loss = tf.reduce_mean(100 * loss_f(inputs - outputs) + kl_loss)
    model = tf.keras.Model(inputs, outputs, name="vae")
    model.add_loss(loss)
    model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.1))
    model.summary()

    x_train = np.random.randn(10000, 1)
    x_val = np.random.randn(10000, 1)
    x_test = np.random.randn(10, 1)

    hist = model.fit(
        x_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_val, None))

    y_test = model(x_test.astype(np.float32))
    reconstruction_loss = float(tf.reduce_mean(loss_f(x_test - y_test)))
    print("reconstruction_loss:", reconstruction_loss)
    self.assertLess(reconstruction_loss, 0.05)
def runTest(self):
    prior_network = PriorNetwork(self.latent_dim)
    encoder = vae.VariationalAutoEncoder(
        self.latent_dim,
        prior_network=prior_network,
        preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))
    decoding_layers = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
    ])
    inputs = tf.keras.layers.Input(shape=self.input_shape)
    prior_inputs = tf.keras.layers.Input(shape=(1, ))
    prior_inputs_one_hot = tf.reshape(
        tf.one_hot(tf.cast(prior_inputs, tf.int32), 10), shape=(-1, 10))
    z, kl_loss = encoder.sampling_forward((prior_inputs_one_hot, inputs))
    outputs = decoding_layers(z)
    loss = tf.reduce_mean(
        tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)
    model = tf.keras.Model(
        inputs=[prior_inputs, inputs], outputs=outputs, name="vae")
    model.add_loss(loss)
    model.compile(optimizer='adam')
    model.summary()
    hist = model.fit(
        [self.y_train, self.x_train],
        epochs=self.epochs,
        batch_size=self.batch_size,
        validation_data=([self.y_test, self.x_test], None))
    last_val_loss = hist.history['val_loss'][-1]
    # total loss is much smaller with the label-based prior network.
    print("loss: ", last_val_loss)
    self.assertTrue(34.0 < last_val_loss < 35.5)
    if INTERACTIVE_MODE:
        self.show_encoded_images(model, with_priors=True)
        # with a prior network, sampling is more complicated
        nrows = 10
        fig = plt.figure()
        idx = 0
        for i in range(10):
            z_mean_prior, z_log_var_prior = prior_network(
                tf.stack([tf.one_hot(i, 10) for _ in range(10)]))
            eps = tf.random.normal((nrows, self.latent_dim),
                                   dtype=tf.float32,
                                   mean=0.,
                                   stddev=1.0)
            # scale eps by the prior's std, i.e. exp(0.5 * log_var)
            sampled_outputs = decoding_layers(
                eps * tf.exp(0.5 * z_log_var_prior) + z_mean_prior)
            # for the same digit i, we sample a bunch of images;
            # it actually looks great.
            for j in range(nrows):
                fig.add_subplot(nrows, nrows, idx + 1)
                plt.imshow(
                    np.reshape(sampled_outputs[j],
                               (self.image_size, self.image_size)))
                idx += 1
        plt.show()
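
# With a learned prior network, the KL term above is assumed to be taken
# against N(z_mean_prior, exp(z_log_var_prior)) instead of N(0, I). A sketch
# of the closed-form KL between two diagonal Gaussians (illustrative only):
def _kl_against_prior_sketch(z_mean, z_log_var, z_mean_prior, z_log_var_prior):
    return 0.5 * tf.reduce_sum(
        tf.exp(z_log_var - z_log_var_prior)
        + tf.square(z_mean - z_mean_prior) / tf.exp(z_log_var_prior)
        - 1.0 + z_log_var_prior - z_log_var,
        axis=-1)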
def runTest(self):
    encoder = vae.VariationalAutoEncoder(
        self.latent_dim,
        preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))
    decoding_layers = tf.keras.Sequential([
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
    ])
    inputs = tf.keras.layers.Input(shape=self.input_shape)
    prior_inputs = tf.keras.layers.Input(shape=(1, ))
    prior_inputs_one_hot = tf.reshape(
        tf.one_hot(tf.cast(prior_inputs, tf.int32), 10), shape=(-1, 10))
    z, kl_loss = encoder.sampling_forward(
        tf.concat([prior_inputs_one_hot, inputs], -1))
    outputs = decoding_layers(tf.concat([prior_inputs_one_hot, z], -1))
    loss = tf.reduce_mean(
        tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)
    model = tf.keras.Model(
        inputs=[prior_inputs, inputs], outputs=outputs, name="vae")
    model.add_loss(loss)
    model.compile(optimizer='adam')
    model.summary()
    hist = model.fit(
        [self.y_train, self.x_train],
        epochs=self.epochs,
        batch_size=self.batch_size,
        validation_data=([self.y_test, self.x_test], None))
    last_val_loss = hist.history['val_loss'][-1]
    # cvae seems to have the lowest error with the same settings
    print("loss: ", last_val_loss)
    self.assertTrue(30.0 < last_val_loss < 31.0)
    if INTERACTIVE_MODE:
        self.show_encoded_images(model, with_priors=True)
        nrows = 10
        fig = plt.figure()
        idx = 0
        for i in range(10):
            eps = tf.random.normal((nrows, self.latent_dim),
                                   dtype=tf.float32,
                                   mean=0.,
                                   stddev=1.0)
            conditionals = tf.stack([tf.one_hot(i, 10) for _ in range(10)])
            sampled_outputs = decoding_layers(
                tf.concat([conditionals, eps], -1))
            # for the same digit i, we sample a bunch of images;
            # it actually looks great.
            for j in range(nrows):
                fig.add_subplot(nrows, nrows, idx + 1)
                plt.imshow(
                    np.reshape(sampled_outputs[j],
                               (self.image_size, self.image_size)))
                idx += 1
        plt.show()
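
# Conditional sampling for the cvae above, factored out as a sketch: because
# the decoder is conditioned on the one-hot label, class-specific digits can
# be drawn directly from the N(0, I) prior (mirrors the interactive loop):
def _sample_digit_sketch(decoding_layers, digit, n, latent_dim):
    eps = tf.random.normal((n, latent_dim))
    cond = tf.tile(tf.one_hot([digit], 10), [n, 1])
    return decoding_layers(tf.concat([cond, eps], -1))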
def test_conditional_vae(self):
    """Test for one dimensional Gaussian, conditioned on a Bernoulli
    variable.
    """
    prior_input_spec = BoundedTensorSpec((), 'int64')

    z_prior_network = EncodingNetwork(
        TensorSpec(
            (prior_input_spec.maximum - prior_input_spec.minimum + 1, )),
        fc_layer_params=(10, ) * 2,
        last_layer_size=2 * self._latent_dim,
        last_activation=math_ops.identity)
    preprocess_network = EncodingNetwork(
        input_tensor_spec=(
            z_prior_network.input_tensor_spec,
            self._input_spec,
            z_prior_network.output_spec,
        ),
        preprocessing_combiner=NestConcat(),
        fc_layer_params=(10, ) * 2,
        last_layer_size=self._latent_dim,
        last_activation=math_ops.identity)

    encoder = vae.VariationalAutoEncoder(
        self._latent_dim,
        preprocess_network=preprocess_network,
        z_prior_network=z_prior_network)
    decoding_layers = FC(self._latent_dim, 1)

    optimizer = torch.optim.Adam(
        list(encoder.parameters()) + list(decoding_layers.parameters()),
        lr=0.1)

    x_train = self._input_spec.randn(outer_dims=(10000, ))
    y_train = x_train.clone()
    y_train[:5000] = y_train[:5000] + 1.0
    pr_train = torch.cat([
        prior_input_spec.zeros(outer_dims=(5000, )),
        prior_input_spec.ones(outer_dims=(5000, ))
    ], dim=0)

    x_test = self._input_spec.randn(outer_dims=(100, ))
    y_test = x_test.clone()
    y_test[:50] = y_test[:50] + 1.0
    pr_test = torch.cat([
        prior_input_spec.zeros(outer_dims=(50, )),
        prior_input_spec.ones(outer_dims=(50, ))
    ], dim=0)
    pr_test = torch.nn.functional.one_hot(
        pr_test,
        int(z_prior_network.input_tensor_spec.shape[0])).to(torch.float32)

    for _ in range(self._epochs):
        idx = torch.randperm(x_train.shape[0])
        x_train = x_train[idx]
        y_train = y_train[idx]
        pr_train = pr_train[idx]
        for i in range(0, x_train.shape[0], self._batch_size):
            optimizer.zero_grad()
            batch = x_train[i:i + self._batch_size]
            y_batch = y_train[i:i + self._batch_size]
            pr_batch = torch.nn.functional.one_hot(
                pr_train[i:i + self._batch_size],
                int(z_prior_network.input_tensor_spec.shape[0])).to(
                    torch.float32)
            alg_step = encoder.train_step([pr_batch, batch])
            outputs = decoding_layers(alg_step.output)
            loss = torch.mean(100 * self._loss_f(y_batch - outputs)
                              + alg_step.info.loss)
            loss.backward()
            optimizer.step()

    y_hat_test = decoding_layers(
        encoder.train_step([pr_test, x_test]).output)
    reconstruction_loss = float(
        torch.mean(self._loss_f(y_test - y_hat_test)))
    print("reconstruction_loss:", reconstruction_loss)
    self.assertLess(reconstruction_loss, 0.05)