Exemple #1
0
 def __init__(
     self,
     latents: LayerCreator = RVconf(64,
                                    'vdeterministic',
                                    projection=True,
                                    name="Latents"),
     dropout: float = 0.3,
     activity_l2coeff: float = 0.,
     activity_l1coeff: float = 0.,
     weights_l2coeff: float = 0.,
     weights_l1coeff: float = 0.,
     **kwargs,
 ):
     """Deterministic-latent autoencoder setup.

     Every latent config/layer is coerced into a vector-deterministic
     posterior, then regularization coefficients are stored.

     BUG FIX: the converted `deterministic_latents` used to be built and
     then discarded — `super().__init__` received the untouched `latents`.
     It now receives the converted layers (mirroring the observation
     handling in the sibling Autoencoder constructor).
     """
     deterministic_latents = []
     for qz in as_tuple(latents):
         if isinstance(qz, RVconf):
             # RVconf is mutated in place to request a deterministic posterior
             qz.posterior = 'vdeterministic'
         elif (isinstance(qz, DistributionDense)
               and qz.posterior != VectorDeterministicLayer):
             # already-built layers are copied with the posterior swapped
             qz = _copy_distribution_dense(
                 qz,
                 posterior=VectorDeterministicLayer,
                 posterior_kwargs=dict(name='latents'))
         deterministic_latents.append(qz)
     if len(deterministic_latents) == 1:
         deterministic_latents = deterministic_latents[0]
     super().__init__(latents=deterministic_latents, **kwargs)
     # NOTE(review): other constructors here use dtype=, this one dtype_hint=;
     # presumably intentional (allows non-float dropout input) — confirm.
     self.dropout = tf.convert_to_tensor(dropout,
                                         dtype_hint=self.dtype,
                                         name='dropout')
     self.activity_l2coeff = float(activity_l2coeff)
     self.activity_l1coeff = float(activity_l1coeff)
     self.weights_l2coeff = float(weights_l2coeff)
     self.weights_l1coeff = float(weights_l1coeff)
Exemple #2
0
 def __init__(self,
              latents: Union[RVconf, Layer] = RVconf(64, name="latents"),
              distribution: Literal['powerspherical',
                                    'vonmisesfisher'] = 'vonmisesfisher',
              prior: Union[None, SphericalUniform, VonMisesFisher,
                           PowerSpherical] = None,
              beta: Union[float, Interpolation] = linear(vmin=1e-6,
                                                         vmax=1.,
                                                         steps=2000,
                                                         delay_in=0),
              **kwargs):
     """Wire a hyperspherical latent layer (PowerSpherical or vMF posterior)."""
     latent_shape = latents.event_shape
     latent_size = int(np.prod(latent_shape))
     distribution = str(distribution).lower()
     assert distribution in ('powerspherical', 'vonmisesfisher'), \
       ('Support PowerSpherical or VonMisesFisher distribution, '
        f'but given: {distribution}')
     # pick the distribution factory together with its matching default prior
     if distribution == 'powerspherical':
         make_dist = _power_spherical(latent_size)
         fallback_prior = SphericalUniform(dimension=latent_size)
     else:
         make_dist = _von_mises_fisher(latent_size)
         fallback_prior = VonMisesFisher(0, 10)
     prior = fallback_prior if prior is None else prior
     # one extra unit: direction (latent_size) plus concentration parameter
     latents = DistributionDense(
         latent_shape,
         posterior=DistributionLambda(make_distribution_fn=make_dist),
         prior=prior,
         units=latent_size + 1,
         name=latents.name)
     super().__init__(latents=latents, analytic=True, beta=beta, **kwargs)
Exemple #3
0
 def __init__(self, name: str = 'semafod', **kwargs):
     """Initialize the parent model, then attach an auxiliary latent for y."""
     super().__init__(name=name, **kwargs)
     # the auxiliary latent is sized after the label space, not the latents
     n_units = int(np.prod(self.labels.event_shape))
     self.latents_y = RVconf(n_units,
                             'mvndiag',
                             projection=True,
                             name=f'{self.latents.name}_y').create_posterior()
Exemple #4
0
 def __init__(
     self,
     observation: LayerCreator = RVconf((28, 28, 1),
                                        'bernoulli',
                                        projection=True,
                                        name='image'),
     name='Autoencoder',
     **kwargs,
 ):
     """Plain autoencoder: force every observation to a deterministic
     posterior whose log-prob is the (negative) MSE.

     BUG FIX: `reinterpreted_batch_ndims` used `observation.event_shape`,
     which fails whenever `observation` is a tuple/list of configs; it now
     uses the per-element `px.event_shape`.
     """
     deterministic_obs = []
     for px in as_tuple(observation):
         if isinstance(px, RVconf):
             px.posterior = 'deterministic'
             px.kwargs['log_prob'] = _mse_log_prob
             # reinterpret all event dims so log-prob reduces over the event
             px.kwargs['reinterpreted_batch_ndims'] = len(px.event_shape)
         elif (isinstance(px, DistributionDense)
               and px.posterior != DeterministicLayer):
             px = _copy_distribution_dense(px,
                                           posterior=DeterministicLayer,
                                           posterior_kwargs=dict(
                                               log_prob=_mse_log_prob,
                                               name='MSE'))
         deterministic_obs.append(px)
     if len(deterministic_obs) == 1:
         deterministic_obs = deterministic_obs[0]
     super().__init__(observation=deterministic_obs, name=name, **kwargs)
Exemple #5
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'onehot', projection=True, name="digits"),
     encoder_y: Optional[Union[LayerCreator, Literal['tie',
                                                     'copy']]] = None,
     decoder_y: Optional[Union[LayerCreator, Literal['tie',
                                                     'copy']]] = None,
     alpha: float = 10.,
     n_semi_iw: int = (),
     skip_decoder: bool = False,
     name: str = 'MultitaskVAE',
     **kwargs,
 ):
     """Multitask VAE constructor: builds the label posterior and optional
     y-specific encoder/decoder branches.

     `encoder_y` accepts a layer config, or the strings 'tie' (reuse the
     main encoder, add only a posterior head) / 'copy' (clone the main
     encoder architecture).
     """
     super().__init__(name=name, **kwargs)
     self.labels = _parse_layers(labels)
     self.labels: DistributionDense
     self.alpha = tf.convert_to_tensor(alpha,
                                       dtype=self.dtype,
                                       name='alpha')
     # NOTE(review): annotated `int` but defaults to `()` — presumably the
     # number of importance-weighted samples for the semi-supervised term;
     # confirm the expected type with callers.
     self.n_semi_iw = n_semi_iw
     self.skip_decoder = bool(skip_decoder)
     ## prepare encoder for Y
     if encoder_y is not None:
         # total latent units: sum of each latent layer's flattened size
         units_z = sum(
             np.prod(z.event_shape if hasattr(z, 'event_shape') else z.
                     output_shape) for z in as_tuple(self.latents))
         if isinstance(encoder_y, string_types):  # copy
             if encoder_y == 'tie':
                 # 'tie': no backbone of its own, only the posterior head
                 layers = []
             elif encoder_y == 'copy':
                 # 'copy': clone the main encoder's architecture
                 layers = [
                     keras.models.clone_model(self.encoder),
                     keras.layers.Flatten()
                 ]
             else:
                 raise ValueError(f'No support for encoder_y={encoder_y}')
         else:  # different network
             layers = [_parse_layers(encoder_y)]
         # posterior head q(z_y|x) appended after the chosen backbone
         layers.append(
             RVconf(units_z, 'mvndiag', projection=True,
                    name='qzy_x').create_posterior())
         encoder_y = keras.Sequential(layers, name='encoder_y')
     self.encoder_y = encoder_y
     ## prepare decoder for Y
     if decoder_y is not None:
         decoder_y = _parse_layers(decoder_y)
     self.decoder_y = decoder_y
Exemple #6
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'relaxedonehot', name='digits'),
     observation: RVconf = RVconf((28, 28, 1),
                                  'bernoulli',
                                  projection=True,
                                  name='image'),
     latents: RVconf = RVconf(54, 'mvndiag', projection=True, name='latents'),
     classifier: LayerCreator = NetConf([128, 128],
                                        flatten_inputs=True,
                                        name='classifier'),
     encoder: LayerCreator = NetConf([512, 512],
                                     flatten_inputs=True,
                                     name='encoder'),
     decoder: LayerCreator = NetConf([512, 512],
                                     flatten_inputs=True,
                                     name='decoder'),
     n_resamples: int = 128,
     alpha: float = 0.05,
     temperature: float = 10.,
     name: str = 'ReparameterizedM3VAE',
     **kwargs,
 ):
     """Reparameterized M3 VAE: relaxed one-hot labels with a uniform
     categorical prior, plus a normal 'denotations' head and a classifier."""
     super().__init__(latents=latents,
                      observation=observation,
                      encoder=encoder,
                      decoder=decoder,
                      name=name,
                      **kwargs)
     assert labels.posterior == 'relaxedonehot', \
       f"only support 'relaxedonehot' distribution for labels, given {labels.posterior}"
     n_classes = int(np.prod(labels.event_shape))
     self.marginalize = False
     self.n_classes = n_classes
     self.n_resamples = int(n_resamples)
     self.regressor = PriorRegressor(n_classes)
     # uniform categorical prior over all classes
     uniform_prior = OneHotCategorical(probs=[1. / n_classes] * n_classes)
     self.labels = RVconf(n_classes,
                          posterior='relaxedonehot',
                          projection=True,
                          prior=uniform_prior,
                          name=labels.name,
                          kwargs=dict(temperature=temperature)).create_posterior()
     self.denotations = RVconf(event_shape=(n_classes,),
                               posterior='normal',
                               projection=True,
                               name='denotations').create_posterior()
     self.classifier = _parse_layers(classifier)
Exemple #7
0
 def __init__(self,
              latents: RVconf = RVconf(5,
                                       'mvndiag',
                                       projection=True,
                                       name='Latents'),
              factors: RVconf = RVconf(5,
                                       'mvndiag',
                                       projection=True,
                                       name="Factors"),
              **kwargs):
     """Append the factors posterior as an extra latent variable."""
     assert isinstance(factors, RVconf), \
       "factors must be instance of RVmeta, but given: %s" % \
       str(type(factors))
     # flatten to a plain list so the factors config can be appended
     all_latents = tf.nest.flatten(latents) + [factors]
     super().__init__(latents=all_latents,
                      latent_dim=int(np.prod(factors.event_shape)),
                      **kwargs)
     self.factors = factors
Exemple #8
0
 def __init__(
     self,
     encoder: List[keras.layers.Layer],
     decoder: List[keras.layers.Layer],
     layers_map: List[Tuple[str, str, int]] = (
         ('encoder2', 'decoder2', 16),
         ('encoder1', 'decoder3', 16),
         ('encoder0', 'decoder4', 16),
     ),
     beta: float = 10.,
     free_bits: float = 2.,
     name: str = 'PUnetVAE',
     **kwargs,
 ):
     """Build ladder latent (q, p) pairs linking named encoder/decoder layers."""
     encoder, decoder = _prepare_encoder_decoder(encoder, decoder)
     super().__init__(encoder=encoder,
                      decoder=decoder,
                      beta=beta,
                      name=name,
                      free_bits=free_bits,
                      **kwargs)
     enc_by_name = {layer.name: layer for layer in self.encoder}
     dec_by_name = {layer.name: layer for layer in self.decoder}
     self.ladder_latents = {}
     n_latents = 0
     for enc_name, dec_name, units in layers_map:
         # skip mappings whose endpoints are not present in the networks
         if enc_name not in enc_by_name or dec_name not in dec_by_name:
             continue
         # one posterior (q) / prior (p) pair per mapped layer couple
         self.ladder_latents[enc_name] = RVconf(
             units, 'mvndiag', projection=True,
             name=f'ladder_q{n_latents}').create_posterior()
         self.ladder_latents[dec_name] = RVconf(
             units, 'mvndiag', projection=True,
             name=f'ladder_p{n_latents}').create_posterior()
         n_latents += 1
     self.flatten = keras.layers.Flatten()
Exemple #9
0
 def __init__(
     self,
     mi_coef: Coefficient = 0.2,
     latents: RVconf = RVconf(32, 'mvndiag', projection=True, name='latents'),
     mutual_codes: RVconf = RVconf(10,
                                   'mvndiag',
                                   projection=True,
                                   name='codes'),
     steps_without_mi: int = 100,
     beta: Coefficient = linear(vmin=1e-6, vmax=1., steps=2000),
     beta_codes: Coefficient = 0.,
     name: str = 'MutualInfoVAE',
     **kwargs,
 ):
     """Mutual-information VAE: stores MI/beta schedules and the codes layer."""
     super().__init__(beta=beta, latents=latents, name=name, **kwargs)
     # record binarity before the config is turned into a concrete layer
     self.is_binary_code = mutual_codes.is_binary
     if isinstance(mutual_codes, RVconf):
         mutual_codes = mutual_codes.create_posterior()
     self.mutual_codes = mutual_codes
     self.steps_without_mi = int(steps_without_mi)
     self._mi_coef = mi_coef
     self._beta_codes = beta_codes
Exemple #10
0
 def __init__(
     self,
     ldd: LatentDirichletDecoder,
     encoder: LayerCreator = NetConf([300, 300, 300],
                                     flatten_inputs=True,
                                     name="Encoder"),
     decoder: LayerCreator = NetConf([300, 300, 300],
                                     flatten_inputs=True,
                                     name="Decoder"),
     latents: LayerCreator = RVconf(10, 'mvndiag', True, name='Latents'),
     warmup: Optional[int] = None,
     beta: float = 1.0,
     alpha: float = 1.0,
     **kwargs,
 ):
     """Constructor stub: declares the latent-Dirichlet decoder and the
     encoder/decoder/latent configurations; the body is elided here."""
     ...
Exemple #11
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'onehot', projection=True, name="Labels"),
     alpha: float = 10.,
     ss_strategy: Literal['sum', 'logsumexp', 'mean', 'max',
                          'min'] = 'logsumexp',
     name: str = 'SemiFactorVAE',
     **kwargs,
 ):
     """Semi-supervised factor VAE: forwards config and stores `alpha`."""
     super().__init__(ss_strategy=ss_strategy, labels=labels, name=name,
                      **kwargs)
     # keep alpha as a tensor so it can scale losses inside the graph
     self.alpha = tf.convert_to_tensor(alpha, dtype=self.dtype, name='alpha')
     self.n_labels = self.discriminator.n_outputs
Exemple #12
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'onehot', projection=True, name="digits"),
     alpha: float = 10.0,
     mi_coef: Union[float, Interpolation] = linear(vmin=0.1,
                                                   vmax=0.05,
                                                   steps=20000),
     reverse_mi: bool = False,
     steps_without_mi: int = 1000,
     **kwargs,
 ):
     """Store the semi-supervised MI configuration and build the label layer."""
     super().__init__(**kwargs)
     self._separated_steps = False
     self.labels = _parse_layers(labels)
     self.alpha = alpha
     self._mi_coef = mi_coef
     self.reverse_mi = bool(reverse_mi)
     self.steps_without_mi = int(steps_without_mi)
Exemple #13
0
 def __init__(
     self,
     batchnorm: bool = False,
     input_dropout: float = 0.,
     dropout: float = 0.,
     units: Sequence[int] = (1000, 1000, 1000, 1000, 1000),
     observation: Union[RVconf,
                        Sequence[RVconf]] = RVconf(1,
                                                   'bernoulli',
                                                   projection=True,
                                                   name="discriminator"),
     activation: Union[str, Callable[[tf.Tensor],
                                     tf.Tensor]] = tf.nn.leaky_relu,
     ss_strategy: Literal['sum', 'logsumexp', 'mean', 'max',
                          'min'] = 'logsumexp',
     name: str = "FactorDiscriminator",
 ):
     """Dense discriminator over one or several random-variable outputs."""
     if not isinstance(observation, (tuple, list)):
         observation = [observation]
     assert len(observation) > 0, "No output is given for FactorDiscriminator"
     assert all(isinstance(o, (RVconf, DistributionDense))
                for o in observation), (
                    f"outputs must be instance of RVmeta, but given:{observation}")
     # flatten every output's event shape and accumulate the total units
     total_units = 0
     for rv in observation:
         if not rv.projection:
             warnings.warn(f'Projection turn off for observation {rv}!')
         flat_size = int(np.prod(rv.event_shape))
         rv.event_shape = (flat_size, )
         total_units += flat_size
     backbone = dense_network(units=units,
                              batchnorm=batchnorm,
                              dropout=dropout,
                              flatten_inputs=True,
                              input_dropout=input_dropout,
                              activation=activation,
                              prefix=name)
     super().__init__(backbone, name=name)
     self.ss_strategy = str(ss_strategy)
     self.observation = observation
     self.n_outputs = total_units
     self._distributions = []
     assert self.ss_strategy in {'sum', 'logsumexp', 'mean', 'max', 'min'}
Exemple #14
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'relaxedonehot', name='digits'),
     classifier: Sequence[int] = (1024, 1024, 1024, 1024),
     activation: Union[str, Callable[[Any], tf.Tensor]] = 'relu',
     alpha: float = 10.,
     **kwargs,
 ):
     """Semi-supervised VAE heads: classifier plus x/y/z projection layers.

     Fixes:
     - `self.y_to_px` was named 'z_to_px', colliding with `self.z_to_px`;
       it now has its own name 'y_to_px'.
     - `self.n_classes` was computed twice; the duplicate is removed.
     """
     super().__init__(**kwargs)
     self.alpha = float(alpha)
     ## the networks
     # TODO: force reparams here
     self.labels = _parse_layers(labels)
     self.n_classes = int(np.prod(labels.event_shape))
     self.classifier = NetConf(classifier,
                               flatten_inputs=True,
                               activation=activation,
                               name='Classifier').create_network()
     self.xy_to_qz_net = NetConf([128, 128], activation=activation,
                                 name='xy_to_qz').create_network()
     self.zy_to_px_net = NetConf([128, 128], activation=activation,
                                 name='zy_to_px').create_network()
     # q(z|xy)
     self.y_to_qz = Dense(128, activation='linear', name='y_to_qz')
     self.x_to_qz = Dense(128, activation='linear', name='x_to_qz')
     # p(x|zy)
     # BUG FIX: this layer was named 'z_to_px', duplicating the name of the
     # layer below.
     self.y_to_px = Dense(128, activation='linear', name='y_to_px')
     self.z_to_px = Dense(128, activation='linear', name='z_to_px')
     self.concat = Concatenate(axis=-1)
     self.flatten = Flatten()
     self.onehot_dist = DistributionLambda(lambda p: VectorDeterministic(p))
     # classes: per-factor group sizes for the supported datasets
     if self.n_classes in (10, 3, 4):
         self.classes = [self.n_classes]
     elif self.n_classes == (15 + 8 + 4 + 30):  # dsprites
         self.classes = [15, 8, 4, 10, 10, 10]
     elif self.n_classes == (40 + 6 + 3 + 32 + 32):  # shapes3d
         self.classes = [40, 6, 3, 32, 32]
     else:
         raise NotImplementedError
Exemple #15
0
 def __init__(
     self,
     ldd: LatentDirichletDecoder,
     encoder: LayerCreator = NetConf(flatten_inputs=True, name="Encoder"),
     decoder: LayerCreator = NetConf(flatten_inputs=True, name="Decoder"),
     latents: LayerCreator = RVconf(10,
                                    posterior='mvndiag',
                                    projection=True,
                                    name="Latents"),
     **kwargs,
 ):
     """Delegate construction, then freeze the encoder and all latent layers."""
     super().__init__(ldd=ldd,
                      encoder=encoder,
                      decoder=decoder,
                      latents=latents,
                      **kwargs)
     # neither the encoder nor the KL over the latents is trained here
     self.encoder.trainable = False
     for latent_layer in self.latent_layers:
         latent_layer.trainable = False
Exemple #16
0
 def __init__(self,
              discriminator_units: Sequence[int] = (1000, 1000, 1000, 1000,
                                                    1000),
              discriminator_optim: Optional[tf.optimizers.Optimizer] = None,
              activation: Union[str, Callable[[], Any]] = tf.nn.relu,
              batchnorm: bool = False,
              tc_coef: float = 7.0,
              maximize_tc: bool = False,
              name: str = 'FactorVAE',
              **kwargs):
     """FactorVAE setup: total-correlation coefficient plus a discriminator."""
     # pop these so they never reach the parent constructor
     ss_strategy = kwargs.pop('ss_strategy', 'logsumexp')
     labels = kwargs.pop(
         'labels',
         RVconf(1, 'bernoulli', projection=True, name="discriminator"))
     super().__init__(name=name, **kwargs)
     self.tc_coef = tf.convert_to_tensor(tc_coef, dtype=self.dtype,
                                         name='tc_coef')
     ## init discriminator
     self.discriminator = FactorDiscriminator(
         units=as_tuple(discriminator_units),
         activation=activation,
         batchnorm=batchnorm,
         ss_strategy=ss_strategy,
         observation=labels)
     if discriminator_optim is None:
         discriminator_optim = tf.optimizers.Adam(learning_rate=1e-5,
                                                  beta_1=0.5,
                                                  beta_2=0.9)
     self.discriminator_optim = discriminator_optim
     ## discriminator and VAE parameters are optimized separately
     self.disc_params = []
     self.vae_params = []
     self.maximize_tc = bool(maximize_tc)
     ## For training
     # the pretraining flag lets train_steps change behavior later without
     # re-writing the method
     self._is_pretraining = False
Exemple #17
0
 def __init__(
     self,
     labels: RVconf = RVconf(10, 'onehot', name='digits'),
     observation: RVconf = RVconf((28, 28, 1),
                                  'bernoulli',
                                  projection=True,
                                  name='image'),
     latents: RVconf = RVconf(64, 'mvndiag', projection=True, name='latents'),
     classifier: LayerCreator = NetConf([128, 128],
                                        flatten_inputs=True,
                                        name='classifier'),
     encoder: LayerCreator = NetConf([512, 512],
                                     flatten_inputs=True,
                                     name='encoder'),
     decoder: LayerCreator = NetConf([512, 512],
                                     flatten_inputs=True,
                                     name='decoder'),
     xy_to_qz: LayerCreator = NetConf([128, 128], name='xy_to_qz'),
     zy_to_px: LayerCreator = NetConf([128, 128], name='zy_to_px'),
     embedding_dim: int = 128,
     embedding_method: Literal['repetition', 'projection', 'dictionary',
                               'sequential', 'identity'] = 'sequential',
     batchnorm: bool = False,
     dropout: float = 0.,
     alpha: float = 0.05,
     beta: float = 1.,
     temperature: float = 10.,
     marginalize: bool = True,
     name: str = 'ConditionalM2VAE',
     **kwargs,
 ):
     """Conditional M2 VAE: one-hot (or relaxed one-hot) labels condition
     both the approximate posterior q(z|x,y) and the decoder p(x|z,y).

     Fixes:
     - `posterior_name` could be referenced while unbound when `labels`
       had neither `posterior` nor `posterior_layer`; it now defaults to ''.
     - `batchnorm` annotation corrected from `str` to `bool` (the default
       and all uses are boolean).
     """
     super().__init__(latents=latents,
                      observation=observation,
                      encoder=encoder,
                      decoder=decoder,
                      beta=beta,
                      name=name,
                      **kwargs)
     self.alpha = tf.convert_to_tensor(alpha, dtype=self.dtype, name="alpha")
     self.embedding_dim = int(embedding_dim)
     self.embedding_method = str(embedding_method)
     self.batchnorm = bool(batchnorm)
     self.dropout = float(dropout)
     ## the networks
     self.classifier = _parse_layers(classifier)
     self.xy_to_qz_net = _parse_layers(xy_to_qz)
     self.zy_to_px_net = _parse_layers(zy_to_px)
     ## check the labels distribution
     # BUG FIX: previously unbound when labels had neither attribute
     posterior_name = ''
     if hasattr(labels, 'posterior'):
         posterior_name = str(labels.posterior)
     if hasattr(labels, 'posterior_layer'):
         # posterior_layer takes precedence when both attributes exist
         posterior_name = str(labels.posterior_layer).lower()
     if 'onehot' not in posterior_name:
         warnings.warn(
             'Conditional VAE only support one-hot or relaxed one-hot distribution, '
             f'but given: {labels}')
     self.n_classes = int(np.prod(labels.event_shape))
     self.marginalize = bool(marginalize)
     # labels distribution: marginalizing implies the exact one-hot posterior
     if marginalize:
         temperature = 0
     if temperature == 0.:
         posterior = 'onehot'
         dist_kw = dict()
         self.relaxed = False
     else:
         posterior = 'relaxedonehot'
         dist_kw = dict(temperature=temperature)
         self.relaxed = True
     self.labels = RVconf(self.n_classes,
                          posterior,
                          projection=True,
                          prior=OneHotCategorical(probs=[1. / self.n_classes] *
                                                        self.n_classes),
                          name=labels.name,
                          kwargs=dist_kw).create_posterior()
     # create embedder
     embedder = get_embedding(self.embedding_method)
     # q(z|xy)
     self.y_to_qz = embedder(n_classes=self.n_classes,
                             event_shape=self.embedding_dim,
                             name='y_to_qz')
     self.x_to_qz = Dense(embedding_dim, activation='linear', name='x_to_qz')
     # p(x|zy)
     self.y_to_px = embedder(n_classes=self.n_classes,
                             event_shape=self.embedding_dim,
                             name='y_to_px')
     self.z_to_px = Dense(embedding_dim, activation='linear', name='z_to_px')
     # batch normalization
     if self.batchnorm:
         self.qz_xy_norm = BatchNormalization(axis=-1, name='qz_xy_norm')
         self.px_zy_norm = BatchNormalization(axis=-1, name='px_zy_norm')
     if 0.0 < self.dropout < 1.0:
         self.qz_xy_drop = Dropout(rate=self.dropout, name='qz_xy_drop')
         self.px_zy_drop = Dropout(rate=self.dropout, name='px_zy_drop')
Exemple #18
0
def mnist_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Network for MNIST dataset image size (28, 28, 1).

    Returns a dict with keys 'encoder', 'decoder', 'observation', 'latents'
    (plus 'labels' when `is_semi_supervised`).  `kwargs` may carry
    'n_channels', 'distribution', 'proj_dim'-style options and 'labels_name'.
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi import BiConvLatents
    n_channels = int(kwargs.get('n_channels', 1))
    proj_dim = 196
    input_shape = (28, 28, n_channels)
    if zdim is None:
        zdim = 32
    conv, deconv = _prepare_cnn(activation=activation)
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'bernoulli'))
    # encoder: two stride-2 stages shrink 28x28 -> 7x7, then a linear projection
    encoder = SequentialNetwork(
        [
            keras.layers.InputLayer(input_shape),
            CenterAt0(enable=centerize_image),
            conv(32, 5, strides=1, name='encoder0'),  # 28, 28, 32
            conv(32, 5, strides=2, name='encoder1'),  # 14, 14, 32
            conv(64, 5, strides=1, name='encoder2'),  # 14, 14, 64
            conv(64, 5, strides=2, name='encoder3'),  # 7 , 7 , 64
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    # decoder mirrors the encoder; BiConvLatents adds a hierarchical latent
    # connection and is unwrapped below for non-hierarchical models
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((7, 7, proj_dim // 49)),  # 7, 7, 4
        deconv(64, 5, strides=2, name='decoder2'),  # 14, 14, 64
        # NOTE(review): encoder.layers[3] — which conv stage this resolves to
        # depends on whether keras counts the InputLayer; confirm it targets
        # the intended encoder layer.
        BiConvLatents(
            conv(64, 5, strides=1, name='decoder3'),  # 14, 14, 64
            encoder=encoder.layers[3],
            filters=16,
            kernel_size=14,
            strides=7,
            disable=True,
            name='latents2'),
        deconv(32, 5, strides=2, name='decoder4'),  # 28, 28, 32
        conv(32, 5, strides=1, name='decoder5'),  # 28, 28, 32
        conv(n_channels * n_params,
             1,
             strides=1,
             activation='linear',
             name='decoder6'),
        last_layer
    ]
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(
            10,
            'onehot',
            projection=True,
            name=kwargs.get('labels_name', 'digits'),
        ).create_posterior()
    return networks
Exemple #19
0
 def __init__(self,
              latents=RVconf(5, 'mvndiag', projection=True, name='Latents'),
              factors=RVconf(5, 'mvndiag', projection=True, name='Factors'),
              **kwargs):
     """Thin wrapper: forwards the latent/factor configs to the parent."""
     # NOTE(review): the RVconf defaults are shared mutable objects; verify
     # no caller mutates them in place.
     super().__init__(factors=factors, latents=latents, **kwargs)
 def __init__(
     self,
     observation: LayerCreator = RVconf((28, 28, 1),
                                        'bernoulli',
                                        projection=True,
                                        name='image'),
     latents: Optional[LayerCreator] = RVconf(16,
                                              'mvndiag',
                                              projection=True,
                                              name="latents"),
     encoder: Optional[LayerCreator] = None,
     decoder: Optional[LayerCreator] = None,
     **kwargs,
 ):
     """Build the encoder / latents / decoder / observation layers, each of
     which may be a single config or a tuple/list of configs.

     BUG FIX: when `observation` is a tuple/list, each element is now parsed
     individually (`network=o`); previously the whole `observation` sequence
     was passed to `_parse_layers` for every element.
     """
     if encoder is None:
         encoder = NetConf((512, 512), flatten_inputs=True, name="encoder")
     if decoder is None:
         decoder = NetConf((512, 512),
                           flatten_inputs=True,
                           flatten_outputs=True,
                           name="decoder")
     ### keras want this supports_masking on to enable support masking
     super().__init__(**kwargs)
     ### create layers
     # encoder
     if isinstance(encoder, (tuple, list)):
         self._encoder = [
             _parse_layers(network=e, name=f"encoder{i}")
             for i, e in enumerate(encoder)
         ]
         self._encoder_args = [_get_args(e) for e in self._encoder]
     else:
         self._encoder = _parse_layers(network=encoder, name="encoder")
         self._encoder_args = _get_args(self.encoder)
     # latents
     if isinstance(latents, (tuple, list)):
         self._latents = [
             _parse_layers(network=z, name=f"latents{i}")
             for i, z in enumerate(latents)
         ]
         self._latents_args = [_get_args(z) for z in self.latents]
     else:
         self._latents = _parse_layers(network=latents, name="latents")
         self._latents_args = _get_args(self.latents)
     # decoder
     if isinstance(decoder, (tuple, list)):
         self._decoder = [
             _parse_layers(network=d, name=f"decoder{i}")
             for i, d in enumerate(decoder)
         ]
         self._decoder_args = [_get_args(d) for d in self.decoder]
     else:
         self._decoder = _parse_layers(network=decoder, name="decoder")
         self._decoder_args = _get_args(self.decoder)
     # observation
     if isinstance(observation, (tuple, list)):
         # BUG FIX: parse each element `o`, not the whole sequence
         self._observation = [
             _parse_layers(network=o, name=f"observation{i}")
             for i, o in enumerate(observation)
         ]
         self._observation_args = [_get_args(o) for o in self.observation]
     else:
         self._observation = _parse_layers(network=observation,
                                           name="observation")
         self._observation_args = _get_args(self.observation)
Exemple #21
0
def dsprites_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Networks for the dSprites dataset, image shape (64, 64, n_channels).

    Parameters
    ----------
    qz : name of the posterior distribution for the latents (e.g. 'mvndiag').
    zdim : number of latent dimensions, defaults to 10.
    activation : activation used by all conv/deconv layers.
    is_semi_supervised : if True, add a 'labels' DistributionDense head.
    is_hierarchical : if True, keep the BiConvLatents wrapper in the decoder;
        otherwise it is unwrapped to its plain inner layer.
    centerize_image : center pixel values around 0 before encoding.
    skip_generator : use SkipSequential instead of SequentialNetwork for
        the decoder.
    **kwargs : recognized overrides: 'n_channels' (default 1),
        'proj_dim' (default 128 for 1 channel else 256),
        'distribution' (default 'bernoulli').

    Returns
    -------
    dict with keys 'encoder', 'decoder', 'observation', 'latents'
    (plus 'labels' when semi-supervised).
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi.autoencoder import BiConvLatents
    if zdim is None:
        zdim = 10
    n_channels = int(kwargs.get('n_channels', 1))
    input_shape = (64, 64, n_channels)
    conv, deconv = _prepare_cnn(activation=activation)
    # projection size on both sides of the latent bottleneck
    proj_dim = kwargs.get('proj_dim', None)
    if proj_dim is None:
        proj_dim = 128 if n_channels == 1 else 256
    else:
        proj_dim = int(proj_dim)
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'bernoulli'))
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=2, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((4, 4, proj_dim // 16)),
        # NOTE(review): this is the only BiConvLatents in the decoder yet it
        # is named 'latents2' — confirm against saved checkpoints before
        # renaming.
        BiConvLatents(deconv(64, 4, strides=2, name='decoder1'),
                      encoder=encoder.layers[3],
                      filters=32,
                      kernel_size=8,
                      strides=4,
                      disable=True,
                      name='latents2'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        # NOTE: this last projection layer with linear activation is crucial
        # otherwise the distribution parameterized by this layer won't converge
        conv(n_channels * n_params,
             1,
             strides=1,
             activation='linear',
             name='decoder6'),
        last_layer
    ]
    # unwrap BiConvLatents into its plain inner layer for flat (non-
    # hierarchical) models
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers.dense_distribution import DistributionDense
        # TODO: update
        networks['labels'] = DistributionDense(
            event_shape=(5, ),
            posterior=_dsprites_distribution,
            units=9,
            name='geometry2d')
    return networks
Exemple #22
0
def celeba_networks(qz: str = 'mvndiag',
                    zdim: Optional[int] = None,
                    activation: Union[Callable, str] = tf.nn.elu,
                    is_semi_supervised: bool = False,
                    is_hierarchical: bool = False,
                    centerize_image: bool = True,
                    skip_generator: bool = False,
                    n_labels: int = 18,
                    **kwargs):
    """Networks for the CelebA dataset, image shape (64, 64, 3).

    Parameters
    ----------
    qz : name of the posterior distribution for the latents.
    zdim : number of latent dimensions, defaults to 45.
    activation : activation used by all conv/deconv layers.
    is_semi_supervised : if True, add a 'labels' head for binary attributes.
    is_hierarchical : if True, keep BiConvLatents wrappers (currently a
        no-op here since the decoder contains none).
    centerize_image : center pixel values around 0 before encoding.
    skip_generator : use SkipSequential instead of SequentialNetwork for
        the decoder.
    n_labels : number of attributes for the semi-supervised head.

    Returns
    -------
    dict with keys 'encoder', 'decoder', 'observation', 'latents'
    (plus 'labels' when semi-supervised).
    """
    from odin.bay.random_variable import RVconf
    if zdim is None:
        zdim = 45
    input_shape = (64, 64, 3)
    n_channels = input_shape[-1]
    conv, deconv = _prepare_cnn(activation=activation)
    proj_dim = 512
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=1, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((8, 8, proj_dim // 64)),
        deconv(64, 4, strides=1, name='decoder1'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        conv(
            2 * n_channels,
            # MixtureQuantizedLogistic.params_size(n_components, n_channels),
            1,
            strides=1,
            activation='linear',
            name='decoder5'),
    ]
    from odin.bay import BiConvLatents
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    # BUGFIX: _parse_distribution returns a (n_params, observation_layer,
    # last_layer) tuple — see its other call sites in this file; previously
    # the whole tuple was stored as the observation layer.  Only the
    # distribution layer is needed here.
    _, observation, _ = _parse_distribution(input_shape, 'qlogistic')
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers import DistributionDense
        networks['labels'] = DistributionDense(event_shape=n_labels,
                                               posterior=_celeba_distribution,
                                               units=n_labels,
                                               name='attributes')
    return networks
Exemple #23
0
def pbmc_networks(
        qz: str = 'mvndiag',
        zdim: Optional[int] = 32,
        activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
        is_semi_supervised: bool = False,
        is_hierarchical: bool = False,
        log_norm: bool = True,
        cnn: bool = True,
        units: Sequence[int] = (512, 512, 512),
        **kwargs,
) -> Dict[str, Layer]:
    """Build encoder/decoder networks for an mRNA-sequencing input of
    2019 genes, with a ZINB observation model and (optionally) a negative-
    binomial 'adt' head for semi-supervised training.

    When ``cnn`` is True a 1-D convolutional encoder/decoder pair is used;
    otherwise plain dense stacks of ``units`` are built.
    """
    from odin.bay.random_variable import RVconf
    input_shape = (2019, )
    n_labels = 32
    if zdim is None:
        zdim = 32
    if cnn:
        # 1-D convolutional variant
        conv1d = partial(keras.layers.Conv1D,
                         strides=2,
                         padding='same',
                         activation=activation)
        deconv1d = partial(keras.layers.Conv1DTranspose,
                           strides=2,
                           padding='same',
                           activation=activation)
        enc_layers = [
            LogNorm(enable=log_norm),
            # add a channel axis: (n, 2019) -> (n, 2019, 1)
            keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
            conv1d(32, 7, name='encoder0'),
            conv1d(64, 5, name='encoder1'),
            conv1d(128, 3, name='encoder2'),
            conv1d(128, 3, name='encoder3'),
            keras.layers.Flatten(),
        ]
        dec_layers = [
            keras.layers.Dense(256, activation=activation, name='decoder0'),
            # add a channel axis: (n, 256) -> (n, 256, 1)
            keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
            deconv1d(128, 3, strides=1, name='decoder1'),
            deconv1d(128, 3, name='decoder2'),
            deconv1d(64, 5, name='decoder3'),
            deconv1d(32, 7, name='decoder4'),
            deconv1d(1, 1, strides=1, name='decoder5'),
            keras.layers.Flatten(),
        ]
    else:
        # plain dense variant
        enc_layers = [LogNorm(enable=log_norm)]
        enc_layers += [
            keras.layers.Dense(u, activation=activation, name=f'encoder{i}')
            for i, u in enumerate(units)
        ]
        dec_layers = [
            keras.layers.Dense(u, activation=activation, name=f'decoder{i}')
            for i, u in enumerate(units)
        ]
    encoder = SequentialNetwork(enc_layers, name='encoder')
    decoder = SequentialNetwork(dec_layers, name='decoder')
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    observation = RVconf(input_shape, "zinb", projection=True,
                         name="mrna").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(n_labels,
                                    'nb',
                                    projection=True,
                                    name='adt').create_posterior()
    return networks
Exemple #24
0
 def __init__(
     self,
     n_classes: int = 10,
     observation=RVconf((28, 28, 1),
                        'bernoulli',
                        projection=True,
                        name='image'),
     latents: RVconf = RVconf(64, 'mvndiag', projection=True, name='latents'),
     classifier: LayerCreator = NetConf([128, 128],
                                              flatten_inputs=True,
                                              name='classifier'),
     auxiliary: RVconf = RVconf(64,
                                'mvndiag',
                                projection=True,
                                name='auxiliary'),
     encoder_a: LayerCreator = NetConf([512, 512],
                                             flatten_inputs=True,
                                             name='encoder_a'),
     decoder_a: LayerCreator = NetConf([512, 512],
                                             flatten_inputs=True,
                                             name='decoder_a'),
     encoder: LayerCreator = NetConf([512, 512],
                                           flatten_inputs=True,
                                           name='encoder'),
     decoder: LayerCreator = NetConf([512, 512],
                                           flatten_inputs=True,
                                           name='decoder'),
     axy_to_qz: LayerCreator = NetConf([128, 128], name='axy_to_qz'),
     azy_to_px: LayerCreator = NetConf([128, 128], name='azy_to_px'),
     embedding_dim: int = 128,
     embedding_method: Literal['repetition', 'projection', 'dictionary',
                               'sequential', 'identity'] = 'sequential',
     batchnorm: bool = False,
     dropout: float = 0.,
     skip_connection: bool = True,
     alpha: float = 1.0,
     beta: float = 1.0,
     temperature: float = 10.,
     marginalize: bool = True,
     name='AuxiliaryVAE',
     **kwargs,
 ):
   """Build the auxiliary-variable VAE: creates the q(a|x)/p(a|x,z)
   posterior pair, the auxiliary encoder/decoder, the linear embedding
   connections between x, a, y and z, and optional batchnorm/dropout.
   """
   super().__init__(n_classes=n_classes,
                    observation=observation,
                    latents=latents,
                    classifier=classifier,
                    encoder=encoder,
                    decoder=decoder,
                    xy_to_qz=axy_to_qz,
                    zy_to_px=azy_to_px,
                    embedding_dim=embedding_dim,
                    embedding_method=embedding_method,
                    batchnorm=batchnorm,
                    dropout=dropout,
                    alpha=alpha,
                    beta=beta,
                    temperature=temperature,
                    marginalize=marginalize,
                    name=name,
                    **kwargs)
   self.skip_connection = bool(skip_connection)
   self.batchnorm = bool(batchnorm)
   # posterior q(a|x) and prior-side p(a|x,z) over the auxiliary variable
   self.qa_dist = auxiliary.create_posterior(name='qa_x')
   self.pa_dist = auxiliary.create_posterior(name='pa_xz')
   self.encoder_a = _parse_layers(encoder_a)
   self.decoder_a = _parse_layers(decoder_a)
   # labels connections
   self.x_to_qy = Dense(units=self.embedding_dim, activation='linear')
   self.a_to_qy = Dense(units=self.embedding_dim, activation='linear')
   # auxiliary connections
   self.a_to_qz = Dense(units=self.embedding_dim, activation='linear')
   self.a_to_px = Dense(units=self.embedding_dim, activation='linear')
   # for p(a|yz)
   self.y_to_pa = Dense(units=self.embedding_dim, activation='linear')
   self.z_to_pa = Dense(units=self.embedding_dim, activation='linear')
   # batchnorm and dropout
   if self.batchnorm:
     self.qy_ax_norm = BatchNormalization(axis=-1, name='qy_ax_norm')
     self.pa_zy_norm = BatchNormalization(axis=-1, name='pa_zy_norm')
   if 0.0 < self.dropout < 1.0:
     self.qy_ax_drop = Dropout(rate=self.dropout, name='qy_ax_drop')
     # BUGFIX: was Dropout(axis=-1, ...) — keras Dropout has no `axis`
     # argument and the required `rate` was missing; mirror qy_ax_drop.
     self.pa_zy_drop = Dropout(rate=self.dropout, name='pa_zy_drop')
Exemple #25
0
def cifar_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Network for CIFAR dataset image size (32, 32, 3)

    Parameters
    ----------
    qz : name of the posterior distribution for the latents.
    zdim : number of latent dimensions, defaults to 256.
    activation : activation used by all conv/deconv layers.
    is_semi_supervised : if True, add an 'onehot' labels head.
    is_hierarchical : if True, keep the BiConvLatents wrappers in the
        decoder; otherwise each wrapper is unwrapped to its inner layer.
    centerize_image : center pixel values around 0 before encoding.
    skip_generator : use SkipSequential instead of SequentialNetwork for
        the decoder.
    **kwargs : recognized overrides: 'n_channels' (default 3),
        'n_classes' (default 10), 'distribution' (default 'qlogistic').

    Returns
    -------
    dict with keys 'encoder', 'decoder', 'observation', 'latents'
    (plus 'labels' when semi-supervised).
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi.autoencoder.hierarchical_vae import BiConvLatents
    if zdim is None:
        zdim = 256
    n_channels = kwargs.get('n_channels', 3)
    input_shape = (32, 32, n_channels)
    conv, deconv = _prepare_cnn(activation=activation)
    n_classes = kwargs.get('n_classes', 10)
    proj_dim = 8 * 8 * 8
    ## output distribution
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'qlogistic'))
    ## encoder
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=1, name='encoder0'),  # 32, 32, 32
            conv(32, 4, strides=2, name='encoder1'),  # 16, 16, 32
            conv(64, 4, strides=1, name='encoder2'),  # 16, 16, 64
            conv(64, 4, strides=2, name='encoder3'),  # 8, 8, 64
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    # NOTE: the BiConvLatents wrappers reference encoder layers by index
    # (encoder.layers[3] and encoder.layers[1]); do not reorder the encoder
    # layer list above without updating these indices.
    layers = [
        keras.layers.Dense(proj_dim, activation='linear', name='decoder_proj'),
        keras.layers.Reshape((8, 8, proj_dim // 64)),  # 8, 8, 4
        deconv(64, 4, strides=2, name='decoder1'),  # 16, 16, 64
        BiConvLatents(
            conv(64, 4, strides=1, name='decoder2'),  # 16, 16, 64
            encoder=encoder.layers[3],
            filters=32,
            kernel_size=8,
            strides=4,
            disable=True,
            name='latents1'),
        deconv(32, 4, strides=2, name='decoder3'),  # 32, 32, 32
        BiConvLatents(
            conv(32, 4, strides=1, name='decoder4'),  # 32, 32, 32
            encoder=encoder.layers[1],
            filters=16,
            kernel_size=8,
            strides=4,
            disable=True,
            name='latents2'),
        conv(
            n_channels * n_params,  # 32, 32, 3
            1,
            strides=1,
            activation='linear',
            name='decoder5'),
        last_layer
    ]
    # unwrap each BiConvLatents into its inner layer for flat (non-
    # hierarchical) models
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    ## others
    latents = RVconf((zdim, ), qz, projection=True,
                     name="latents").create_posterior()
    # create the observation of MixtureQuantizedLogistic
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(n_classes,
                                    'onehot',
                                    projection=True,
                                    name='labels').create_posterior()
    return networks