def loss_aegmm(x_true: tf.Tensor,
               x_pred: tf.Tensor,
               z: tf.Tensor,
               gamma: tf.Tensor,
               w_energy: float = .1,
               w_cov_diag: float = .005) -> tf.Tensor:
    """
    Loss function used for OutlierAEGMM.

    Combines the autoencoder reconstruction error with the GMM sample
    energy and a covariance regularization term.

    Parameters
    ----------
    x_true
        Batch of instances.
    x_pred
        Batch of reconstructed instances by the autoencoder.
    z
        Latent space values.
    gamma
        Membership prediction for mixture model components.
    w_energy
        Weight on sample energy loss term.
    w_cov_diag
        Weight on covariance regularizing loss term.

    Returns
    -------
    Loss value.
    """
    # Mean squared reconstruction error of the autoencoder.
    mse = tf.reduce_mean(tf.square(x_true - x_pred))
    # Fit mixture parameters from the latent representation and soft
    # membership assignments, then score the batch under the mixture.
    phi, mu, cov, chol, log_det_cov = gmm_params(z, gamma)
    energy, cov_diag_term = gmm_energy(z, phi, mu, cov, chol, log_det_cov, return_mean=True)
    return mse + w_energy * energy + w_cov_diag * cov_diag_term
def test_gmm_params_energy():
    """Check shapes and basic properties of GMM parameters and sample energy."""
    # NOTE(review): relies on module-level fixtures z, gamma, K, D, N — presumably
    # defined alongside this test; verify they exist in the test module.
    phi, mu, cov, L, log_det_cov = gmm_params(z, gamma)
    # One mixture weight and one log-determinant per component.
    assert phi.numpy().shape[0] == K == log_det_cov.numpy().shape[0]
    assert mu.numpy().shape == (K, D)
    assert cov.numpy().shape == L.numpy().shape == (K, D, D)
    # Per-component covariance and Cholesky factor must have non-negative diagonals.
    # Use a named index (`_` would conventionally signal an *unused* variable).
    for k in range(K):
        assert (np.diag(cov[k].numpy()) >= 0.).all()
        assert (np.diag(L[k].numpy()) >= 0.).all()
    # With return_mean=True both outputs are scalars.
    sample_energy, cov_diag = gmm_energy(z, phi, mu, cov, L, log_det_cov, return_mean=True)
    assert sample_energy.numpy().shape == cov_diag.numpy().shape == ()
    # Without the mean reduction the energy is per-instance.
    sample_energy, _ = gmm_energy(z, phi, mu, cov, L, log_det_cov, return_mean=False)
    assert sample_energy.numpy().shape[0] == N
def fit(self,
        X: np.ndarray,
        loss_fn: tf.keras.losses = loss_aegmm,
        w_energy: float = .1,
        w_cov_diag: float = .005,
        optimizer: tf.keras.optimizers = None,
        epochs: int = 20,
        batch_size: int = 64,
        verbose: bool = True,
        log_metric: Tuple[str, "tf.keras.metrics"] = None,
        callbacks: tf.keras.callbacks = None,
        ) -> None:
    """
    Train AEGMM model.

    Parameters
    ----------
    X
        Training batch.
    loss_fn
        Loss function used for training.
    w_energy
        Weight on sample energy loss term if default `loss_aegmm` loss fn is used.
    w_cov_diag
        Weight on covariance regularizing loss term if default `loss_aegmm` loss fn is used.
    optimizer
        Optimizer used for training. Defaults to a fresh
        `tf.keras.optimizers.Adam(learning_rate=1e-4)` instance.
    epochs
        Number of training epochs.
    batch_size
        Batch size used for training.
    verbose
        Whether to print training progress.
    log_metric
        Additional metrics whose progress will be displayed if verbose equals True.
    callbacks
        Callbacks used during training.
    """
    # Construct the optimizer here instead of in the signature: a default
    # evaluated at def time would be a single Adam instance shared across
    # every call and every detector, carrying over its moment/slot state.
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    # train arguments
    args = [self.aegmm, loss_fn, X]
    kwargs = {'optimizer': optimizer,
              'epochs': epochs,
              'batch_size': batch_size,
              'verbose': verbose,
              'log_metric': log_metric,
              'callbacks': callbacks,
              'loss_fn_kwargs': {'w_energy': w_energy,
                                 'w_cov_diag': w_cov_diag}
              }

    # train
    trainer(*args, **kwargs)

    # set GMM parameters from a forward pass over the training data
    x_recon, z, gamma = self.aegmm(X)
    self.phi, self.mu, self.cov, self.L, self.log_det_cov = gmm_params(z, gamma)
def loss_vaegmm(x_true: tf.Tensor,
                x_pred: tf.Tensor,
                z: tf.Tensor,
                gamma: tf.Tensor,
                w_recon: float = 1e-7,
                w_energy: float = .1,
                w_cov_diag: float = .005,
                cov_full: tf.Tensor = None,
                cov_diag: tf.Tensor = None,
                sim: float = .05) -> tf.Tensor:
    """
    Loss function used for OutlierVAEGMM.

    Parameters
    ----------
    x_true
        Batch of instances.
    x_pred
        Batch of reconstructed instances by the variational autoencoder.
    z
        Latent space values.
    gamma
        Membership prediction for mixture model components.
    w_recon
        Weight on elbo loss term.
    w_energy
        Weight on sample energy loss term.
    w_cov_diag
        Weight on covariance regularizing loss term.
    cov_full
        Full covariance matrix.
    cov_diag
        Diagonal (variance) of covariance matrix.
    sim
        Scale identity multiplier.

    Returns
    -------
    Loss value.
    """
    recon_loss = elbo(x_true, x_pred, cov_full=cov_full, cov_diag=cov_diag, sim=sim)
    phi, mu, cov, L, log_det_cov = gmm_params(z, gamma)
    # Name the energy's covariance-diagonal penalty distinctly: previously it
    # shadowed the `cov_diag` *parameter* (the elbo variance), which was
    # harmless only because elbo() had already been called, but misleading.
    # return_mean=True is explicit for consistency with `loss_aegmm`
    # (it is also gmm_energy's default, so behavior is unchanged).
    sample_energy, cov_diag_penalty = gmm_energy(z, phi, mu, cov, L, log_det_cov,
                                                 return_mean=True)
    loss = w_recon * recon_loss + w_energy * sample_energy + w_cov_diag * cov_diag_penalty
    return loss
def fit(self,
        X: np.ndarray,
        loss_fn: tf.keras.losses = loss_vaegmm,
        w_recon: float = 1e-7,
        w_energy: float = .1,
        w_cov_diag: float = .005,
        optimizer: tf.keras.optimizers = None,
        cov_elbo: dict = dict(sim=.05),
        epochs: int = 20,
        batch_size: int = 64,
        verbose: bool = True,
        log_metric: Tuple[str, "tf.keras.metrics"] = None,
        callbacks: tf.keras.callbacks = None,
        ) -> None:
    """
    Train VAEGMM model.

    Parameters
    ----------
    X
        Training batch.
    loss_fn
        Loss function used for training.
    w_recon
        Weight on elbo loss term if default `loss_vaegmm`.
    w_energy
        Weight on sample energy loss term if default `loss_vaegmm` loss fn is used.
    w_cov_diag
        Weight on covariance regularizing loss term if default `loss_vaegmm` loss fn is used.
    optimizer
        Optimizer used for training. Defaults to a fresh
        `tf.keras.optimizers.Adam(learning_rate=1e-4)` instance.
    cov_elbo
        Dictionary with covariance matrix options in case the elbo loss function is used.
        Either use the full covariance matrix inferred from X (dict(cov_full=None)),
        only the variance (dict(cov_diag=None)) or a float representing the same standard
        deviation for each feature (e.g. dict(sim=.05)).
    epochs
        Number of training epochs.
    batch_size
        Batch size used for training.
    verbose
        Whether to print training progress.
    log_metric
        Additional metrics whose progress will be displayed if verbose equals True.
    callbacks
        Callbacks used during training.
    """
    # Construct the optimizer here instead of in the signature: a default
    # evaluated at def time would be a single Adam instance shared across
    # every call and every detector, carrying over its moment/slot state.
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
    # train arguments
    args = [self.vaegmm, loss_fn, X]
    kwargs = {'optimizer': optimizer,
              'epochs': epochs,
              'batch_size': batch_size,
              'verbose': verbose,
              'log_metric': log_metric,
              'callbacks': callbacks,
              'loss_fn_kwargs': {'w_recon': w_recon,
                                 'w_energy': w_energy,
                                 'w_cov_diag': w_cov_diag}
              }

    # initialize covariance matrix if default vaegmm loss fn is used;
    # cov_elbo is a single-entry dict, so take its only key and value
    use_elbo = loss_fn.__name__ == 'loss_vaegmm'
    cov_elbo_type, cov = [*cov_elbo][0], [*cov_elbo.values()][0]
    if use_elbo and cov_elbo_type in ['cov_full', 'cov_diag']:
        # infer the (full) feature covariance from the flattened training batch
        cov = tfp.stats.covariance(X.reshape(X.shape[0], -1))
        if cov_elbo_type == 'cov_diag':  # infer standard deviation from covariance matrix
            cov = tf.math.sqrt(tf.linalg.diag_part(cov))
    if use_elbo:
        kwargs['loss_fn_kwargs'][cov_elbo_type] = tf.dtypes.cast(cov, tf.float32)

    # train
    trainer(*args, **kwargs)

    # set GMM parameters from a forward pass over the training data
    x_recon, z, gamma = self.vaegmm(X)
    self.phi, self.mu, self.cov, self.L, self.log_det_cov = gmm_params(z, gamma)