Example #1
    def runTest(self):
        encoder = vae.VariationalAutoEncoder(
            self.latent_dim,
            preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))

        decoding_layers = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
        ])

        inputs = tf.keras.layers.Input(shape=self.input_shape)
        z, kl_loss = encoder.sampling_forward(inputs)
        outputs = decoding_layers(z)

        loss = tf.reduce_mean(
            tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)

        model = tf.keras.Model(inputs, outputs, name="vae")
        model.add_loss(loss)
        model.compile(optimizer='adam')
        model.summary()

        hist = model.fit(self.x_train,
                         epochs=self.epochs,
                         batch_size=self.batch_size,
                         validation_data=(self.x_test, None))

        last_val_loss = hist.history['val_loss'][-1]
        print("loss: ", last_val_loss)
        self.assertTrue(37.5 < last_val_loss <= 39.0)
        if INTERACTIVE_MODE:
            self.show_encoded_images(model)
            self.show_sampled_images(lambda eps: decoding_layers(eps))
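Example #1 leans on encoder.sampling_forward, which returns a latent sample together with a KL penalty but whose internals are not shown. The sketch below is a hypothetical stand-in assuming a diagonal Gaussian posterior against a standard N(0, I) prior; the Dense mean/log-variance heads and the class name are illustrative assumptions, not the library's actual implementation.

import tensorflow as tf


class SamplingForwardSketch(tf.keras.layers.Layer):
    """Hypothetical stand-in for encoder.sampling_forward (illustration only)."""

    def __init__(self, latent_dim):
        super().__init__()
        self._mean = tf.keras.layers.Dense(latent_dim)
        self._log_var = tf.keras.layers.Dense(latent_dim)

    def call(self, features):
        z_mean = self._mean(features)
        z_log_var = self._log_var(features)
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable.
        eps = tf.random.normal(tf.shape(z_mean))
        z = z_mean + tf.exp(0.5 * z_log_var) * eps
        # Analytic KL(q(z|x) || N(0, I)), summed over the latent dimensions.
        kl = 0.5 * tf.reduce_sum(
            tf.exp(z_log_var) + tf.square(z_mean) - 1.0 - z_log_var, axis=-1)
        return z, kl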
Example #2
    def test_vae(self):
        """Test for one dimensional Gaussion."""
        encoder = vae.VariationalAutoEncoder(
            self._latent_dim, input_tensor_spec=self._input_spec)
        decoding_layers = FC(self._latent_dim, 1)

        optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(decoding_layers.parameters()),
            lr=0.1)

        x_train = self._input_spec.randn(outer_dims=(10000, ))
        x_test = self._input_spec.randn(outer_dims=(10, ))

        for _ in range(self._epochs):
            x_train = x_train[torch.randperm(x_train.shape[0])]
            for i in range(0, x_train.shape[0], self._batch_size):
                optimizer.zero_grad()
                batch = x_train[i:i + self._batch_size]
                alg_step = encoder.train_step(batch)
                outputs = decoding_layers(alg_step.output)
                loss = torch.mean(100 * self._loss_f(batch - outputs) +
                                  alg_step.info.loss)
                loss.backward()
                optimizer.step()

        y_test = decoding_layers(encoder.train_step(x_test).output)
        reconstruction_loss = float(torch.mean(self._loss_f(x_test - y_test)))
        print("reconstruction_loss:", reconstruction_loss)
        self.assertLess(reconstruction_loss, 0.05)
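In Example #2, alg_step.info.loss carries the KL penalty and the factor of 100 weights the squared reconstruction error against it. A rough plain-PyTorch version of the same per-batch objective, assuming the encoder exposed hypothetical z_mean and z_log_var tensors, might look like this:

import torch


def vae_loss_sketch(x, x_recon, z_mean, z_log_var, recon_weight=100.0):
    """Hypothetical per-batch objective mirroring Example #2 (not the library's API)."""
    # Squared reconstruction error, heavily weighted as in the loop above.
    recon = torch.sum((x - x_recon) ** 2, dim=-1)
    # Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
    kl = 0.5 * torch.sum(
        torch.exp(z_log_var) + z_mean ** 2 - 1.0 - z_log_var, dim=-1)
    return torch.mean(recon_weight * recon + kl)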
Example #3
    def test_gaussian(self):
        """Test for one dimensional Gaussion."""
        input_shape = (1, )
        epochs = 20
        batch_size = 100
        latent_dim = 1
        loss_f = tf.square

        encoder = vae.VariationalAutoEncoder(latent_dim)
        decoding_layers = tf.keras.layers.Dense(1)

        inputs = tf.keras.layers.Input(shape=input_shape)
        z, kl_loss = encoder.sampling_forward(inputs)
        outputs = decoding_layers(z)

        loss = tf.reduce_mean(100 * loss_f(inputs - outputs) + kl_loss)

        model = tf.keras.Model(inputs, outputs, name="vae")
        model.add_loss(loss)
        model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.1))
        model.summary()

        x_train = np.random.randn(10000, 1)
        x_val = np.random.randn(10000, 1)
        x_test = np.random.randn(10, 1)

        hist = model.fit(
            x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_val, None))

        y_test = model(x_test.astype(np.float32))
        reconstruction_loss = float(tf.reduce_mean(loss_f(x_test - y_test)))
        print("reconstruction_loss:", reconstruction_loss)
        self.assertLess(reconstruction_loss, 0.05)
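Once Example #3's model has converged, the generative half can be used on its own: draw latent codes from the N(0, I) prior and decode them. A minimal usage sketch (the untrained Dense layer below merely stands in for the trained decoding_layers above):

import tensorflow as tf

latent_dim = 1
decoding_layers = tf.keras.layers.Dense(1)  # stand-in for the trained decoder

# Decode a handful of latent codes drawn from the standard normal prior.
eps = tf.random.normal((5, latent_dim))
samples = decoding_layers(eps)
print(samples.shape)  # (5, 1)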
Example #4
    def runTest(self):
        prior_network = PriorNetwork(self.latent_dim)
        encoder = vae.VariationalAutoEncoder(
            self.latent_dim,
            prior_network=prior_network,
            preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))

        decoding_layers = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
        ])

        inputs = tf.keras.layers.Input(shape=self.input_shape)
        prior_inputs = tf.keras.layers.Input(shape=(1, ))
        prior_inputs_one_hot = tf.reshape(
            tf.one_hot(tf.cast(prior_inputs, tf.int32), 10), shape=(-1, 10))
        z, kl_loss = encoder.sampling_forward((prior_inputs_one_hot, inputs))
        outputs = decoding_layers(z)

        loss = tf.reduce_mean(
            tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)

        model = tf.keras.Model(inputs=[prior_inputs, inputs],
                               outputs=outputs,
                               name="vae")
        model.add_loss(loss)
        model.compile(optimizer='adam')
        model.summary()

        hist = model.fit([self.y_train, self.x_train],
                         epochs=self.epochs,
                         batch_size=self.batch_size,
                         validation_data=([self.y_test, self.x_test], None))

        last_val_loss = hist.history['val_loss'][-1]
        # total loss is much smaller with a label-based prior network.
        print("loss: ", last_val_loss)
        self.assertTrue(34.0 < last_val_loss < 35.5)
        if INTERACTIVE_MODE:
            self.show_encoded_images(model, with_priors=True)

            # with prior network, sampling is more complicated
            nrows = 10
            fig = plt.figure()
            idx = 0
            for i in range(10):
                z_mean_prior, z_log_var_prior = prior_network(
                    tf.stack([tf.one_hot(i, 10) for _ in range(10)]))

                eps = tf.random.normal((nrows, self.latent_dim),
                                       dtype=tf.float32,
                                       mean=0.,
                                       stddev=1.0)
                # z_log_var_prior is a log-variance, so convert it to a
                # standard deviation before scaling the unit-normal noise.
                sampled_outputs = decoding_layers(
                    eps * tf.exp(0.5 * z_log_var_prior) + z_mean_prior)
                # for the same digit i, we sample a bunch of images
                # it actually looks great.
                for j in range(nrows):
                    fig.add_subplot(nrows, nrows, idx + 1)
                    plt.imshow(
                        np.reshape(sampled_outputs[j],
                                   (self.image_size, self.image_size)))
                    idx += 1

            plt.show()
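The PriorNetwork in Example #4 maps a one-hot digit label to the mean and log-variance of a label-conditioned prior over z, and the encoder's KL term is then taken against that prior rather than N(0, I). Below is a minimal Keras sketch under that assumption; the hidden width and class name are made up for illustration and are not the test's actual PriorNetwork.

import tensorflow as tf


class PriorNetworkSketch(tf.keras.Model):
    """Hypothetical label-conditioned prior: one-hot label -> (mean, log_var)."""

    def __init__(self, latent_dim):
        super().__init__()
        self._hidden = tf.keras.layers.Dense(64, activation='relu')
        self._mean = tf.keras.layers.Dense(latent_dim)
        self._log_var = tf.keras.layers.Dense(latent_dim)

    def call(self, one_hot_labels):
        h = self._hidden(one_hot_labels)
        return self._mean(h), self._log_var(h)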
Example #5
    def runTest(self):
        encoder = vae.VariationalAutoEncoder(
            self.latent_dim,
            preprocess_layers=tf.keras.layers.Dense(512, activation='relu'))

        decoding_layers = tf.keras.Sequential([
            tf.keras.layers.Dense(512, activation='relu'),
            tf.keras.layers.Dense(self.original_dim, activation='sigmoid')
        ])

        inputs = tf.keras.layers.Input(shape=self.input_shape)
        prior_inputs = tf.keras.layers.Input(shape=(1, ))
        prior_inputs_one_hot = tf.reshape(
            tf.one_hot(tf.cast(prior_inputs, tf.int32), 10), shape=(-1, 10))

        z, kl_loss = encoder.sampling_forward(
            tf.concat([prior_inputs_one_hot, inputs], -1))
        outputs = decoding_layers(tf.concat([prior_inputs_one_hot, z], -1))

        loss = tf.reduce_mean(
            tf.keras.losses.mse(inputs, outputs) * self.original_dim + kl_loss)

        model = tf.keras.Model(inputs=[prior_inputs, inputs],
                               outputs=outputs,
                               name="vae")
        model.add_loss(loss)
        model.compile(optimizer='adam')
        model.summary()

        hist = model.fit([self.y_train, self.x_train],
                         epochs=self.epochs,
                         batch_size=self.batch_size,
                         validation_data=([self.y_test, self.x_test], None))

        last_val_loss = hist.history['val_loss'][-1]
        # cvae seems to have the lowest errors with the same settings
        print("loss: ", last_val_loss)
        self.assertTrue(30.0 < last_val_loss < 31.0)

        if INTERACTIVE_MODE:
            self.show_encoded_images(model, with_priors=True)
            nrows = 10
            fig = plt.figure()
            idx = 0
            for i in range(10):
                eps = tf.random.normal((nrows, self.latent_dim),
                                       dtype=tf.float32,
                                       mean=0.,
                                       stddev=1.0)
                conditionals = tf.stack([tf.one_hot(i, 10) for _ in range(10)])
                sampled_outputs = decoding_layers(
                    tf.concat([conditionals, eps], -1))
                # for the same digit i, we sample a bunch of images
                # it actually looks great.
                for j in range(nrows):
                    fig.add_subplot(nrows, nrows, idx + 1)
                    plt.imshow(
                        np.reshape(sampled_outputs[j],
                                   (self.image_size, self.image_size)))
                    idx += 1

            plt.show()
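Because Example #5 conditions the decoder by concatenating the one-hot label with z, class-targeted generation only needs the decoder and the desired label. A small helper in the same spirit (sample_digit_sketch is a hypothetical name; the trained decoding_layers and latent_dim are passed in):

import tensorflow as tf


def sample_digit_sketch(decoding_layers, digit, latent_dim, n=10):
    """Draw n decoder samples conditioned on one digit class (illustrative)."""
    eps = tf.random.normal((n, latent_dim))
    conditionals = tf.tile(tf.one_hot([digit], 10), [n, 1])
    return decoding_layers(tf.concat([conditionals, eps], -1))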
Example #6
    def test_conditional_vae(self):
        """Test for one dimensional Gaussion, conditioned on a Bernoulli variable.
        """
        prior_input_spec = BoundedTensorSpec((), 'int64')

        z_prior_network = EncodingNetwork(
            TensorSpec(
                (prior_input_spec.maximum - prior_input_spec.minimum + 1, )),
            fc_layer_params=(10, ) * 2,
            last_layer_size=2 * self._latent_dim,
            last_activation=math_ops.identity)
        preprocess_network = EncodingNetwork(
            input_tensor_spec=(
                z_prior_network.input_tensor_spec,
                self._input_spec,
                z_prior_network.output_spec,
            ),
            preprocessing_combiner=NestConcat(),
            fc_layer_params=(10, ) * 2,
            last_layer_size=self._latent_dim,
            last_activation=math_ops.identity)

        encoder = vae.VariationalAutoEncoder(
            self._latent_dim,
            preprocess_network=preprocess_network,
            z_prior_network=z_prior_network)
        decoding_layers = FC(self._latent_dim, 1)

        optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(decoding_layers.parameters()),
            lr=0.1)

        x_train = self._input_spec.randn(outer_dims=(10000, ))
        y_train = x_train.clone()
        y_train[:5000] = y_train[:5000] + 1.0
        pr_train = torch.cat([
            prior_input_spec.zeros(outer_dims=(5000, )),
            prior_input_spec.ones(outer_dims=(5000, ))
        ], dim=0)

        x_test = self._input_spec.randn(outer_dims=(100, ))
        y_test = x_test.clone()
        y_test[:50] = y_test[:50] + 1.0
        pr_test = torch.cat([
            prior_input_spec.zeros(outer_dims=(50, )),
            prior_input_spec.ones(outer_dims=(50, ))
        ], dim=0)
        pr_test = torch.nn.functional.one_hot(
            pr_test,
            int(z_prior_network.input_tensor_spec.shape[0])).to(torch.float32)

        for _ in range(self._epochs):
            idx = torch.randperm(x_train.shape[0])
            x_train = x_train[idx]
            y_train = y_train[idx]
            pr_train = pr_train[idx]
            for i in range(0, x_train.shape[0], self._batch_size):
                optimizer.zero_grad()
                batch = x_train[i:i + self._batch_size]
                y_batch = y_train[i:i + self._batch_size]
                pr_batch = torch.nn.functional.one_hot(
                    pr_train[i:i + self._batch_size],
                    int(z_prior_network.input_tensor_spec.shape[0])).to(
                        torch.float32)
                alg_step = encoder.train_step([pr_batch, batch])
                outputs = decoding_layers(alg_step.output)
                loss = torch.mean(100 * self._loss_f(y_batch - outputs) +
                                  alg_step.info.loss)
                loss.backward()
                optimizer.step()

        y_hat_test = decoding_layers(
            encoder.train_step([pr_test, x_test]).output)
        reconstruction_loss = float(
            torch.mean(self._loss_f(y_test - y_hat_test)))
        print("reconstruction_loss:", reconstruction_loss)
        self.assertLess(reconstruction_loss, 0.05)
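With a learned z_prior_network, the KL term folded into alg_step.info.loss is presumably taken between the posterior q(z | label, x) and the label-conditioned prior p(z | label), both diagonal Gaussians. For reference, the closed form of that divergence in a short PyTorch sketch (the function and argument names are assumptions, not the library's API):

import torch


def gaussian_kl_sketch(q_mean, q_log_var, p_mean, p_log_var):
    """KL(N(q_mean, exp(q_log_var)) || N(p_mean, exp(p_log_var))), per sample."""
    return 0.5 * torch.sum(
        p_log_var - q_log_var
        + (torch.exp(q_log_var) + (q_mean - p_mean) ** 2) / torch.exp(p_log_var)
        - 1.0,
        dim=-1)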