Example no. 1
def _parse_distribution(input_shape: Tuple[int, int, int],
                        distribution: Literal['qlogistic', 'mixqlogistic',
                                              'bernoulli', 'gaussian'],
                        n_components=10,
                        n_channels=3) -> Tuple[int, DistributionLambda, Layer]:
    from odin.bay.layers import DistributionDense
    n_channels = input_shape[-1]  # infer channel count from the input shape (overrides the default argument)
    last_layer = Activation('linear')
    # === 1. Quantized logistic
    if distribution == 'qlogistic':
        n_params = 2
        observation = DistributionLambda(
            lambda params: QuantizedLogistic(
                *[
                    # loc
                    p if i == 0 else
                    # Ensure scales are positive and do not collapse to near-zero
                    tf.nn.softplus(p) + tf.cast(tf.exp(-7.), tf.float32)
                    for i, p in enumerate(tf.split(params, 2, -1))
                ],
                low=0,
                high=255,
                inputs_domain='sigmoid',
                reinterpreted_batch_ndims=3),
            convert_to_tensor_fn=Distribution.sample,
            name='image')
    # === 2. Mixture Quantized logistic
    elif distribution == 'mixqlogistic':
        n_params = MixtureQuantizedLogistic.params_size(
            n_components=n_components, n_channels=n_channels) // n_channels
        observation = DistributionLambda(
            lambda params: MixtureQuantizedLogistic(params,
                                                    n_components=n_components,
                                                    n_channels=n_channels,
                                                    inputs_domain='sigmoid',
                                                    high=255,
                                                    low=0),
            convert_to_tensor_fn=Distribution.mean,
            name='image')
    # === 3. Bernoulli
    elif distribution == 'bernoulli':
        n_params = 1
        observation = DistributionDense(
            event_shape=input_shape,
            posterior=lambda p: Independent(Bernoulli(logits=p),
                                            len(input_shape)),
            projection=False,
            name="image")
    # === 4. Gaussian
    elif distribution == 'gaussian':
        n_params = 2
        observation = DistributionDense(
            event_shape=input_shape,
            posterior=lambda p: Independent(Normal(*tf.split(p, 2, -1)),
                                            len(input_shape)),
            projection=False,
            name="image")
    else:
        raise ValueError(f'No support for distribution {distribution}')
    return n_params, observation, last_layer
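A brief usage sketch (hypothetical shape and distribution choice; it assumes the same odin/TensorFlow imports as the snippet above). The helper returns a 3-tuple, which callers such as dsprites_networks in Example no. 5 unpack:

# Hypothetical usage sketch for _parse_distribution; the input shape and
# distribution choice are illustrative, not taken from the source.
n_params, observation, last_layer = _parse_distribution(
    input_shape=(64, 64, 1), distribution='bernoulli')
# For 'bernoulli', n_params is 1, so the decoder's final convolution should
# emit n_channels * n_params feature maps before last_layer is applied.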
Example no. 2
def halfmoons_networks(qz: str = 'mvndiag',
                       zdim: Optional[int] = None,
                       activation: Union[Callable, str] = tf.nn.elu,
                       is_semi_supervised: bool = False,
                       is_hierarchical: bool = False,
                       centerize_image: bool = True,
                       skip_generator: bool = False,
                       **kwargs) -> Dict[str, Layer]:
    if zdim is None:
        zdim = 5
    networks = dsprites_networks(qz=qz,
                                 zdim=zdim,
                                 activation=activation,
                                 is_semi_supervised=False,  # labels are added below instead
                                 is_hierarchical=is_hierarchical,
                                 centerize_image=centerize_image,
                                 skip_generator=skip_generator,
                                 distribution='bernoulli',
                                 n_channels=3)
    if is_semi_supervised:
        from odin.bay.layers import DistributionDense
        networks['labels'] = DistributionDense(
            event_shape=(4, ),
            posterior=_halfmoons_distribution,
            units=10,
            name='geometry3d')
    return networks
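A hedged usage sketch; the returned dictionary layout follows dsprites_networks (Example no. 5), which this function wraps:

# Hypothetical usage; assumes the odin library and TensorFlow are importable.
nets = halfmoons_networks(qz='mvndiag', zdim=5, is_semi_supervised=True)
# Keys inherited from dsprites_networks: 'encoder', 'decoder', 'observation',
# 'latents'; 'labels' is added only when is_semi_supervised=True.
encoder, decoder = nets['encoder'], nets['decoder']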
Example no. 3
 def __init__(
     self,
     n_words: int,
     n_topics: int = 20,
     posterior: Literal['gaussian', 'dirichlet'] = 'dirichlet',
     posterior_activation: Union[str, Callable[[], Tensor]] = 'softplus',
     concentration_clip: bool = True,
     distribution: Literal['onehot', 'negativebinomial', 'binomial',
                           'poisson', 'zinb'] = 'onehot',
     dropout: float = 0.0,
     dropout_strategy: Literal['all', 'warmup', 'finetune'] = 'warmup',
     batch_norm: bool = False,
     trainable_prior: bool = True,
     warmup: int = 10000,
     step: Union[int, Variable] = 0,
     input_shape: Optional[List[int]] = None,
     name: str = "Topics",
 ):
     super().__init__(name=name)
     self.n_words = int(n_words)
     self.n_topics = int(n_topics)
     self.batch_norm = bool(batch_norm)
     self.warmup = int(warmup)
     self.posterior = str(posterior).lower()
     self.distribution = str(distribution).lower()
     self.dropout = float(dropout)
      assert dropout_strategy in ('all', 'warmup', 'finetune'), \
        ("Supported dropout strategies: all, warmup, finetune; "
         f"but given: {dropout_strategy}")
     self.dropout_strategy = str(dropout_strategy)
     if isinstance(step, Variable):
         self.step = step
     else:
         self.step = Variable(int(step),
                              dtype=tf.float32,
                              trainable=False,
                              name="Step")
     ### batch norm
     if self.batch_norm:
         self._batch_norm_layer = BatchNormalization(trainable=True)
     ### posterior
     kw = dict(event_shape=(n_topics, ), name="TopicsPosterior")
     if posterior == 'dirichlet':
         kw['posterior'] = DirichletLayer
         init_value = softplus_inverse(0.7).numpy()
         post_kw = dict(concentration_activation=posterior_activation,
                        concentration_clip=concentration_clip)
     elif posterior == "gaussian":
         kw['posterior'] = MultivariateNormalLayer
         init_value = 0.
         post_kw = dict(covariance='diag',
                        loc_activation='identity',
                        scale_activation=posterior_activation)
     else:
          raise NotImplementedError(
              "Only the following posterior distributions are supported: "
              "'gaussian', 'dirichlet'")
     self.topics_prior_logits = self.add_weight(
         initializer=tf.initializers.constant(value=init_value),
         shape=[1, n_topics],
         trainable=bool(trainable_prior),
         name="topics_prior_logits")
     self.posterior_layer = DistributionDense(
         posterior_kwargs=post_kw,
         prior=self.topics_prior_distribution,
         projection=True,
         **kw)
     ### output distribution
     kw = dict(event_shape=(self.n_words, ), name="WordsDistribution")
     count_activation = 'softplus'
     if self.distribution in ('onehot', ):
         self.distribution_layer = OneHotCategoricalLayer(probs_input=True,
                                                          **kw)
         self.n_parameterization = 1
     elif self.distribution in ('poisson', ):
         self.distribution_layer = PoissonLayer(**kw)
         self.n_parameterization = 1
     elif self.distribution in ('negativebinomial', 'nb'):
         self.distribution_layer = NegativeBinomialLayer(
             count_activation=count_activation, **kw)
         self.n_parameterization = 2
     elif self.distribution in ('zinb', ):
         self.distribution_layer = ZINegativeBinomialLayer(
             count_activation=count_activation, **kw)
         self.n_parameterization = 3
     elif self.distribution in ('binomial', ):
         self.distribution_layer = BinomialLayer(
             count_activation=count_activation, **kw)
         self.n_parameterization = 2
     else:
         raise ValueError(
             f"No support for word distribution: {self.distribution}")
     # topics words parameterization
     self.topics_words_params = self.add_weight(
         'topics_words_params',
         shape=[self.n_topics, self.n_words * self.n_parameterization],
         initializer=tf.initializers.glorot_normal(),
         trainable=True)
     # initialize the Model if input_shape given
     if input_shape is not None:
         self.build((None, ) + tuple(input_shape))
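The class name is not visible in this snippet, but Example no. 6 below instantiates a LatentDirichletDecoder with exactly these constructor arguments; a minimal construction sketch with illustrative values:

# Sketch only; the class name and arguments mirror the call in Example no. 6.
lda = LatentDirichletDecoder(
    n_words=2000,          # vocabulary size (illustrative value)
    n_topics=20,
    posterior='dirichlet',
    distribution='onehot',
    warmup=10000,
)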
Example no. 4
def celeba_networks(qz: str = 'mvndiag',
                    zdim: Optional[int] = None,
                    activation: Union[Callable, str] = tf.nn.elu,
                    is_semi_supervised: bool = False,
                    is_hierarchical: bool = False,
                    centerize_image: bool = True,
                    skip_generator: bool = False,
                    n_labels: int = 18,
                    **kwargs):
    from odin.bay.random_variable import RVconf
    if zdim is None:
        zdim = 45
    input_shape = (64, 64, 3)
    n_components = 10  # for Mixture Quantized Logistic
    n_channels = input_shape[-1]
    conv, deconv = _prepare_cnn(activation=activation)
    proj_dim = 512
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=1, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((8, 8, proj_dim // 64)),
        deconv(64, 4, strides=1, name='decoder1'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        conv(
            2 * n_channels,
            # MixtureQuantizedLogistic.params_size(n_components, n_channels),
            1,
            strides=1,
            activation='linear',
            name='decoder5'),
    ]
    from odin.bay import BiConvLatents
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    # _parse_distribution returns (n_params, observation, last_layer);
    # only the observation layer is used here
    _, observation, _ = _parse_distribution(input_shape, 'qlogistic')
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers import DistributionDense
        networks['labels'] = DistributionDense(event_shape=n_labels,
                                               posterior=_celeba_distribution,
                                               units=n_labels,
                                               name='attributes')
    return networks
Example no. 5
def dsprites_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    from odin.bay.random_variable import RVconf
    from odin.bay.vi.autoencoder import BiConvLatents
    if zdim is None:
        zdim = 10
    n_channels = int(kwargs.get('n_channels', 1))
    input_shape = (64, 64, n_channels)
    conv, deconv = _prepare_cnn(activation=activation)
    proj_dim = kwargs.get('proj_dim', None)
    if proj_dim is None:
        proj_dim = 128 if n_channels == 1 else 256
    else:
        proj_dim = int(proj_dim)
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'bernoulli'))
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=2, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    # layers = [
    #   keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
    #   keras.layers.Reshape((4, 4, proj_dim // 16)),
    #   BiConvLatents(deconv(64, 4, strides=2, name='decoder1'),
    #                 encoder=encoder.layers[3],
    #                 filters=32, kernel_size=8, strides=4,
    #                 disable=True, name='latents1'),
    #   deconv(64, 4, strides=2, name='decoder2'),
    #   BiConvLatents(deconv(32, 4, strides=2, name='decoder3'),
    #                 encoder=encoder.layers[1],
    #                 filters=16, kernel_size=8, strides=4,
    #                 disable=True, name='latents2'),
    #   deconv(32, 4, strides=2, name='decoder4'),
    #   # NOTE: this last projection layer with linear activation is crucial
    #   # otherwise the distribution parameterized by this layer won't converge
    #   conv(n_channels * n_params,
    #        1,
    #        strides=1,
    #        activation='linear',
    #        name='decoder6'),
    #   last_layer
    # ]
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((4, 4, proj_dim // 16)),
        BiConvLatents(deconv(64, 4, strides=2, name='decoder1'),
                      encoder=encoder.layers[3],
                      filters=32,
                      kernel_size=8,
                      strides=4,
                      disable=True,
                      name='latents2'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        # NOTE: this last projection layer with linear activation is crucial
        # otherwise the distribution parameterized by this layer won't converge
        conv(n_channels * n_params,
             1,
             strides=1,
             activation='linear',
             name='decoder6'),
        last_layer
    ]
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers.dense_distribution import DistributionDense
        # TODO: update
        networks['labels'] = DistributionDense(
            event_shape=(5, ),
            posterior=_dsprites_distribution,
            units=9,
            name='geometry2d')
    return networks
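The distribution and n_channels keyword arguments are forwarded to _parse_distribution (Example no. 1), which fixes how many output channels the decoder's last convolution needs; a hedged sketch:

# Hypothetical call; 'gaussian' uses n_params=2 per channel, so the final conv
# layer is built with n_channels * 2 output feature maps (see Example no. 1).
nets = dsprites_networks(qz='mvndiag', zdim=10,
                         distribution='gaussian', n_channels=3)
decoder = nets['decoder']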
Example no. 6
def main(cfg):
    save_to_yaml(cfg)
    if cfg.ds == 'news5':
        ds = Newsgroup5()
    elif cfg.ds == 'news20':
        ds = Newsgroup20()
    elif cfg.ds == 'news20clean':
        ds = Newsgroup20_clean()
    elif cfg.ds == 'cortex':
        ds = Cortex()
    elif cfg.ds == 'lkm':
        ds = LeukemiaATAC()
    else:
        raise NotImplementedError(f"No support for dataset: {cfg.ds}")
    train = ds.create_dataset(batch_size=batch_size,
                              partition='train',
                              drop_remainder=True)
    valid = ds.create_dataset(batch_size=batch_size, partition='valid')
    test = ds.create_dataset(batch_size=batch_size, partition='test')
    n_words = ds.vocabulary_size
    vocabulary = ds.vocabulary
    ######## prepare the path
    output_dir = get_output_dir()
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model_path = os.path.join(output_dir, 'model')
    if cfg.override:
        clean_folder(output_dir, verbose=True)
    ######### preparing all layers
    lda = LatentDirichletDecoder(
        posterior=cfg.posterior,
        distribution=cfg.distribution,
        n_words=n_words,
        n_topics=cfg.n_topics,
        warmup=cfg.warmup,
    )
    fit_kw = dict(train=train,
                  valid=valid,
                  max_iter=cfg.n_iter,
                  optimizer='adam',
                  learning_rate=learning_rate,
                  batch_size=batch_size,
                  valid_freq=valid_freq,
                  compile_graph=True,
                  logdir=output_dir,
                  skip_fitted=True)
    output_dist = RVconf(
        n_words,
        cfg.distribution,
        projection=True,
        preactivation='softmax' if cfg.distribution == 'onehot' else 'linear',
        kwargs=dict(probs_input=True) if cfg.distribution == 'onehot' else {},
        name="Words")
    latent_dist = RVconf(cfg.n_topics,
                         'mvndiag',
                         projection=True,
                         name="Latents")
    ######## AmortizedLDA
    if cfg.model == 'lda':
        vae = AmortizedLDA(lda=lda,
                           encoder=NetConf([300, 300, 300], name='Encoder'),
                           decoder='identity',
                           latents='identity',
                           path=model_path)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **fit_kw)
    ######## VDA - Variational Dirichlet Autoencoder
    elif cfg.model == 'vda':
        vae = BetaVAE(
            beta=cfg.beta,
            encoder=NetConf([300, 150], name='Encoder'),
            decoder=NetConf([150, 300], name='Decoder'),
            latents=RVconf(cfg.n_topics,
                           'dirichlet',
                           projection=True,
                           prior=None,
                           name="Topics"),
            outputs=output_dist,
            # important, MCMC KL for Dirichlet is very unstable
            analytic=True,
            path=model_path,
            name="VDA")
        vae.fit(on_valid_end=partial(callback1,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
    ######## VAE
    elif cfg.model == 'model':
        vae = BetaVAE(beta=cfg.beta,
                      encoder=NetConf([300, 300], name='Encoder'),
                      decoder=NetConf([300], name='Decoder'),
                      latents=latent_dist,
                      outputs=output_dist,
                      path=model_path,
                      name="VAE")
        callback1(vae, test, vocabulary)
        vae.fit(on_valid_end=partial(callback1,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
    ######## factorVAE
    elif cfg.model == 'fvae':
        vae = FactorVAE(gamma=6.0,
                        beta=cfg.beta,
                        encoder=NetConf([300, 150], name='Encoder'),
                        decoder=NetConf([150, 300], name='Decoder'),
                        latents=latent_dist,
                        outputs=output_dist,
                        path=model_path)
        vae.fit(on_valid_end=partial(callback1,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       valid_freq=1000,
                       optimizer=[
                           tf.optimizers.Adam(learning_rate=1e-4,
                                              beta_1=0.9,
                                              beta_2=0.999),
                           tf.optimizers.Adam(learning_rate=1e-4,
                                              beta_1=0.5,
                                              beta_2=0.9)
                       ]))
    ######## TwoStageLDA
    elif cfg.model == 'lda2':
        vae0_iter = 10000
        vae0 = BetaVAE(beta=1.0,
                       encoder=NetConf(units=[300], name='Encoder'),
                       decoder=NetConf(units=[300, 300], name='Decoder'),
                       outputs=DistributionDense(
                           (n_words, ),
                           posterior='onehot',
                           posterior_kwargs=dict(probs_input=True),
                           activation='softmax',
                           name="Words"),
                       latents=RVconf(cfg.n_topics,
                                      'mvndiag',
                                      projection=True,
                                      name="Latents"),
                       input_shape=(n_words, ),
                       path=model_path + '_vae0')
        vae0.fit(on_valid_end=lambda: None
                 if get_current_trainer().is_training else vae0.save_weights(),
                 **dict(fit_kw,
                        logdir=output_dir + "_vae0",
                        max_iter=vae0_iter,
                        learning_rate=learning_rate,
                        track_gradients=False))
        vae = TwoStageLDA(lda=lda,
                          encoder=vae0.encoder,
                          decoder=vae0.decoder,
                          latents=vae0.latent_layers,
                          warmup=cfg.warmup - vae0_iter,
                          path=model_path)
        vae.fit(on_valid_end=partial(callback,
                                     vae=vae,
                                     test=test,
                                     vocabulary=vocabulary),
                **dict(fit_kw,
                       max_iter=cfg.n_iter - vae0_iter,
                       track_gradients=False))
    ######## EM-LDA
    elif cfg.model == 'em':
        if os.path.exists(model_path):
            with open(model_path, 'rb') as f:
                lda = pickle.load(f)
        else:
            writer = tf.summary.create_file_writer(output_dir)
            lda = LatentDirichletAllocation(n_components=cfg.n_topics,
                                            doc_topic_prior=0.7,
                                            learning_method='online',
                                            verbose=True,
                                            n_jobs=4,
                                            random_state=1)
            with writer.as_default():
                prog = tqdm(train.repeat(-1), desc="Fitting LDA")
                for n_iter, x in enumerate(prog):
                    lda.partial_fit(x)
                    if n_iter % 500 == 0:
                        text = get_topics_text(lda.components_, vocabulary)
                        perp = lda.perplexity(test)
                        tf.summary.text("topics", text, n_iter)
                        tf.summary.scalar("perplexity", perp, n_iter)
                        prog.write(f"[#{n_iter}]Perplexity: {perp:.2f}")
                        prog.write("\n".join(text))
                    if n_iter >= 20000:
                        break
            with open(model_path, 'wb') as f:
                pickle.dump(lda, f)
        # final evaluation
        text = get_topics_text(lda.components_, vocabulary)
        final_score = lda.perplexity(test)
        tf.summary.scalar("perplexity", final_score, step=n_iter + 1)
        print(f"Perplexity: {final_score:.2f}")
        print("\n".join(text))
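Throughout main, per-model overrides are applied to the shared fit_kw defaults via dict(fit_kw, key=value); a tiny standalone illustration of this plain-Python pattern:

# dict(base, **overrides) copies `base` and replaces only the given keys;
# the values below are illustrative.
fit_defaults = dict(max_iter=8000, valid_freq=500, optimizer='adam')
fit_overridden = dict(fit_defaults, valid_freq=1000, optimizer='sgd')
assert fit_overridden['max_iter'] == 8000    # untouched default
assert fit_overridden['valid_freq'] == 1000  # overridden value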
Example no. 7
def _copy_distribution_dense(p: DistributionDense, posterior,
                             posterior_kwargs):
    """Clone a DistributionDense layer, swapping in a new posterior."""
    init_args = dict(p._init_args)
    init_args['posterior'] = posterior
    init_args['posterior_kwargs'] = posterior_kwargs
    return DistributionDense(**init_args)
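A hedged usage sketch; the posterior name and keyword arguments mirror the 'onehot' configuration from Example no. 6 and are illustrative only:

# Hypothetical usage; `words_layer` stands in for an existing DistributionDense.
words_onehot = _copy_distribution_dense(words_layer,
                                        posterior='onehot',
                                        posterior_kwargs=dict(probs_input=True))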
Example no. 8
 def __init__(self, name: str = 'semafod', **kwargs):
     super().__init__(name=name, **kwargs)
      # clone the labels layer's config to create a second distribution (q),
      # keeping the original layer as p
      labels_kw = self.labels.get_config()
      labels_kw['name'] += '_q'
      self.labels_p = self.labels
      self.labels_q = DistributionDense(**labels_kw)