示例#1
0
def load_tf_vaegmm(filepath: str, state_dict: Dict) -> tf.keras.Model:
    """
    Rebuild a VAEGMM from the networks and checkpoint stored on disk.

    Parameters
    ----------
    filepath
        Save directory. The model files are read from its `model` subdirectory.
    state_dict
        Dictionary containing the `n_gmm`, `latent_dim`, `recon_features` and `beta` parameters.

    Returns
    -------
    Loaded VAEGMM, or None (after logging a warning) when the model directory
    contains no entries other than dot-prefixed ones.
    """
    model_dir = os.path.join(filepath, 'model')

    # entries whose name starts with '.' do not count as saved model files
    has_saved_files = any(not name.startswith('.') for name in os.listdir(model_dir))
    if not has_saved_files:
        logger.warning('No encoder, decoder, gmm density net or vaegmm found in {}.'.format(model_dir))
        return None

    # restore the sub-networks in the order encoder, decoder, GMM density net
    encoder_net, decoder_net, gmm_density_net = (
        tf.keras.models.load_model(os.path.join(model_dir, fname))
        for fname in ('encoder_net.h5', 'decoder_net.h5', 'gmm_density_net.h5')
    )

    vaegmm = VAEGMM(
        encoder_net,
        decoder_net,
        gmm_density_net,
        state_dict['n_gmm'],
        state_dict['latent_dim'],
        state_dict['recon_features'],
        state_dict['beta'],
    )
    # the weights of the assembled model come from the checkpoint in the same directory
    checkpoint_path = os.path.join(model_dir, 'vaegmm.ckpt')
    vaegmm.load_weights(checkpoint_path)
    return vaegmm

@pytest.mark.parametrize('tf_v_ae_mnist', tests, indirect=True)
def test_ae_vae(tf_v_ae_mnist):
    """
    Request the `tf_v_ae_mnist` fixture once per entry of `tests`.

    The test body itself performs no checks; each parameter is handed to the
    fixture (`indirect=True`) rather than to this function.
    """


# GMM density network: input width is `latent_dim + 2`, output is a softmax over `n_gmm` units
n_gmm = 1
gmm_density_net = tf.keras.Sequential()
gmm_density_net.add(InputLayer(input_shape=(latent_dim + 2, )))
gmm_density_net.add(Dense(10, activation=tf.nn.relu))
gmm_density_net.add(Dense(n_gmm, activation=tf.nn.softmax))

# both models share the same encoder, decoder and GMM density network instances
aegmm = AEGMM(encoder_net, decoder_net, gmm_density_net, n_gmm)
vaegmm = VAEGMM(encoder_net, decoder_net, gmm_density_net, n_gmm, latent_dim)

# each test case pairs a model with its loss function
tests = list(zip((aegmm, vaegmm), (loss_aegmm, loss_vaegmm)))
n_tests = len(tests)


@pytest.fixture
def tf_v_aegmm_mnist(request):
    """
    Prepare a small MNIST training subset and select the model under test.

    The MNIST training images are reshaped to `(60000, input_dim)`, cut down to
    the first 1000 rows, cast to float32 and divided by 255. The `(model, loss_fn)`
    pair is then looked up in `tests` at index `request.param`.
    """
    # load and preprocess MNIST data
    (X_train, _), (X_test, _) = tf.keras.datasets.mnist.load_data()
    X = X_train.reshape(60000,
                        input_dim)[:1000]  # only train on 1000 instances
    X = X.astype(np.float32)
    X /= 255

    # init model, predict with untrained model, train and predict with trained model
    # NOTE(review): `model` and `loss_fn` are not used in the lines visible here;
    # the training / prediction steps described above appear to be missing - confirm.
    model, loss_fn = tests[request.param]
示例#3
0
    def __init__(self,
                 threshold: float = None,
                 vaegmm: tf.keras.Model = None,
                 encoder_net: tf.keras.Sequential = None,
                 decoder_net: tf.keras.Sequential = None,
                 gmm_density_net: tf.keras.Sequential = None,
                 n_gmm: int = None,
                 latent_dim: int = None,
                 samples: int = 10,
                 beta: float = 1.,
                 recon_features: Callable = eucl_cosim_features,
                 data_type: str = None
                 ) -> None:
        """
        Outlier detector built on a VAEGMM model.

        Either pass a ready `vaegmm` model, or pass `encoder_net`, `decoder_net`
        and `gmm_density_net` so that a VAEGMM is assembled from them.

        Parameters
        ----------
        threshold
            Outlier score threshold above which instances are flagged as outliers.
            If omitted, a warning is logged and it has to be inferred with `infer_threshold`.
        vaegmm
            An existing tf.keras model. Takes precedence over the separate networks.
        encoder_net
            Encoder layers wrapped in a tf.keras.Sequential; used if no 'vaegmm' is specified.
        decoder_net
            Decoder layers wrapped in a tf.keras.Sequential; used if no 'vaegmm' is specified.
        gmm_density_net
            GMM network layers wrapped in a tf.keras.Sequential; used if no 'vaegmm' is specified.
        n_gmm
            Number of components in GMM.
        latent_dim
            Dimensionality of the latent space.
        samples
            Number of samples sampled to evaluate each instance.
        beta
            Beta parameter for KL-divergence loss term.
        recon_features
            Function to extract features from the reconstructed instance by the decoder.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.

        Raises
        ------
        TypeError
            If `vaegmm` is not a tf.keras.Model and the three networks are not all
            tf.keras.Sequential instances.
        """
        super().__init__()

        self.threshold = threshold
        self.samples = samples
        if threshold is None:
            logger.warning('No threshold level set. Need to infer threshold using `infer_threshold`.')

        # prefer a ready-made model; otherwise assemble one from the three networks
        sub_networks = (encoder_net, decoder_net, gmm_density_net)
        if isinstance(vaegmm, tf.keras.Model):
            self.vaegmm = vaegmm
        elif all(isinstance(net, tf.keras.Sequential) for net in sub_networks):
            self.vaegmm = VAEGMM(
                encoder_net,
                decoder_net,
                gmm_density_net,
                n_gmm,
                latent_dim,
                recon_features=recon_features,
                beta=beta,
            )
        else:
            raise TypeError('No valid format detected for `vaegmm` (tf.keras.Model) '
                            'or `encoder_net`, `decoder_net` and `gmm_density_net` (tf.keras.Sequential).')

        # set metadata
        self.meta['detector_type'] = 'offline'
        self.meta['data_type'] = data_type

        # these five attributes start out unset
        self.phi = self.mu = self.cov = self.L = self.log_det_cov = None