def __init__(
    self,
    latents: LayerCreator = RVconf(64,
                                   'vdeterministic',
                                   projection=True,
                                   name="Latents"),
    dropout: float = 0.3,
    activity_l2coeff: float = 0.,
    activity_l1coeff: float = 0.,
    weights_l2coeff: float = 0.,
    weights_l1coeff: float = 0.,
    **kwargs,
):
    """Deterministic autoencoder: force every latent config to a
    vector-deterministic posterior, then store dropout and the four
    activity/weight regularization coefficients.

    Parameters
    ----------
    latents : one config or a sequence; `RVconf`s are switched to
        'vdeterministic', `DistributionDense` layers are copied with a
        `VectorDeterministicLayer` posterior.
    dropout : input dropout rate, stored as a tensor in model dtype.
    """
    deterministic_latents = []
    for qz in as_tuple(latents):
        if isinstance(qz, RVconf):
            qz.posterior = 'vdeterministic'
        elif (isinstance(qz, DistributionDense) and
              qz.posterior != VectorDeterministicLayer):
            qz = _copy_distribution_dense(
                qz,
                posterior=VectorDeterministicLayer,
                posterior_kwargs=dict(name='latents'))
        deterministic_latents.append(qz)
    if len(deterministic_latents) == 1:
        deterministic_latents = deterministic_latents[0]
    # BUG FIX: the converted `deterministic_latents` were built above but the
    # original `latents` argument was passed through, silently discarding the
    # deterministic conversion.
    super().__init__(latents=deterministic_latents, **kwargs)
    self.dropout = tf.convert_to_tensor(dropout,
                                        dtype_hint=self.dtype,
                                        name='dropout')
    self.activity_l2coeff = float(activity_l2coeff)
    self.activity_l1coeff = float(activity_l1coeff)
    self.weights_l2coeff = float(weights_l2coeff)
    self.weights_l1coeff = float(weights_l1coeff)
def __init__(self,
             latents: Union[RVconf, Layer] = RVconf(64, name="latents"),
             distribution: Literal['powerspherical',
                                   'vonmisesfisher'] = 'vonmisesfisher',
             prior: Union[None, SphericalUniform, VonMisesFisher,
                          PowerSpherical] = None,
             beta: Union[float, Interpolation] = linear(vmin=1e-6,
                                                        vmax=1.,
                                                        steps=2000,
                                                        delay_in=0),
             **kwargs):
    """Hyperspherical-latent VAE.

    Replaces the latent posterior with a PowerSpherical or VonMisesFisher
    distribution on the unit hypersphere, with analytic KL enabled.

    Parameters
    ----------
    latents : config/layer providing `event_shape` and `name`; only those
        two attributes are read here.
    distribution : which spherical posterior family to use.
    prior : explicit prior; defaults to `SphericalUniform(event_size)` for
        'powerspherical', otherwise `VonMisesFisher(0, 10)`.
    beta : KL weight, possibly annealed (`Interpolation`).
    """
    event_shape = latents.event_shape
    event_size = int(np.prod(event_shape))
    distribution = str(distribution).lower()
    assert distribution in ('powerspherical', 'vonmisesfisher'), \
        ('Support PowerSpherical or VonMisesFisher distribution, '
         f'but given: {distribution}')
    if distribution == 'powerspherical':
        fn_distribution = _power_spherical(event_size)
        default_prior = SphericalUniform(dimension=event_size)
    else:
        fn_distribution = _von_mises_fisher(event_size)
        default_prior = VonMisesFisher(0, 10)
    if prior is None:
        prior = default_prior
    # units = event_size + 1: presumably direction parameters plus one
    # concentration parameter — confirm against _power_spherical /
    # _von_mises_fisher implementations.
    latents = DistributionDense(
        event_shape,
        posterior=DistributionLambda(make_distribution_fn=fn_distribution),
        prior=prior,
        units=event_size + 1,
        name=latents.name)
    super().__init__(latents=latents, analytic=True, beta=beta, **kwargs)
def __init__(self, name: str = 'semafod', **kwargs):
    """Create an extra latent posterior q(z|y) sized by the labels."""
    super().__init__(name=name, **kwargs)
    # dimensionality of q(z|y) is tied to the labels' event shape
    n_units = int(np.prod(self.labels.event_shape))
    config = RVconf(n_units,
                    'mvndiag',
                    projection=True,
                    name=f'{self.latents.name}_y')
    self.latents_y = config.create_posterior()
def __init__(
    self,
    observation: LayerCreator = RVconf((28, 28, 1),
                                       'bernoulli',
                                       projection=True,
                                       name='image'),
    name='Autoencoder',
    **kwargs,
):
    """Plain autoencoder: coerce every observation config to a deterministic
    posterior whose log-prob is the MSE-based `_mse_log_prob`.

    Parameters
    ----------
    observation : one config or a sequence; each `RVconf` is switched to a
        'deterministic' posterior, each `DistributionDense` is copied with a
        `DeterministicLayer` posterior.
    """
    deterministic_obs = []
    for px in as_tuple(observation):
        if isinstance(px, RVconf):
            px.posterior = 'deterministic'
            px.kwargs['log_prob'] = _mse_log_prob
            # BUG FIX: previously `len(observation.event_shape)`, which is
            # wrong (or raises) when `observation` is a sequence of configs;
            # use the event shape of the config being processed.
            px.kwargs['reinterpreted_batch_ndims'] = len(px.event_shape)
        elif (isinstance(px, DistributionDense) and
              px.posterior != DeterministicLayer):
            px = _copy_distribution_dense(px,
                                          posterior=DeterministicLayer,
                                          posterior_kwargs=dict(
                                              log_prob=_mse_log_prob,
                                              name='MSE'))
        deterministic_obs.append(px)
    if len(deterministic_obs) == 1:
        deterministic_obs = deterministic_obs[0]
    super().__init__(observation=deterministic_obs, name=name, **kwargs)
def __init__(
    self,
    labels: RVconf = RVconf(10, 'onehot', projection=True, name="digits"),
    encoder_y: Optional[Union[LayerCreator, Literal['tie', 'copy']]] = None,
    decoder_y: Optional[Union[LayerCreator, Literal['tie', 'copy']]] = None,
    alpha: float = 10.,
    n_semi_iw: Union[int, tuple] = (),
    skip_decoder: bool = False,
    name: str = 'MultitaskVAE',
    **kwargs,
):
    """Multitask VAE: adds a supervised labels head and optional dedicated
    encoder/decoder for the label branch.

    Parameters
    ----------
    encoder_y : None (no label encoder), 'tie' (reuse main encoder),
        'copy' (clone main encoder), or a layer config for a separate net.
    decoder_y : optional separate decoder for the label branch.
    alpha : supervised-loss weight, stored as a tensor in model dtype.
    n_semi_iw : importance-weight sample count for semi-supervised loss;
        default `()` — NOTE(review): annotation said `int`, confirm callers.
    """
    super().__init__(name=name, **kwargs)
    self.labels = _parse_layers(labels)
    self.labels: DistributionDense
    self.alpha = tf.convert_to_tensor(alpha, dtype=self.dtype, name='alpha')
    self.n_semi_iw = n_semi_iw
    self.skip_decoder = bool(skip_decoder)
    ## prepare encoder for Y
    if encoder_y is not None:
        # total latent units across all latent layers (event or output shape)
        units_z = sum(
            np.prod(z.event_shape if hasattr(z, 'event_shape') else z.
                    output_shape) for z in as_tuple(self.latents))
        if isinstance(encoder_y, string_types):  # copy
            if encoder_y == 'tie':
                layers = []
            elif encoder_y == 'copy':
                layers = [
                    keras.models.clone_model(self.encoder),
                    keras.layers.Flatten()
                ]
            else:
                raise ValueError(f'No support for encoder_y={encoder_y}')
        else:  # different network
            layers = [_parse_layers(encoder_y)]
        # final posterior layer q(z_y|x) matching the total latent size
        layers.append(
            RVconf(units_z, 'mvndiag', projection=True,
                   name='qzy_x').create_posterior())
        encoder_y = keras.Sequential(layers, name='encoder_y')
    self.encoder_y = encoder_y
    ## prepare decoder for Y
    if decoder_y is not None:
        decoder_y = _parse_layers(decoder_y)
    self.decoder_y = decoder_y
def __init__(
    self,
    labels: RVconf = RVconf(10, 'relaxedonehot', name='digits'),
    observation: RVconf = RVconf((28, 28, 1),
                                 'bernoulli',
                                 projection=True,
                                 name='image'),
    latents: RVconf = RVconf(54, 'mvndiag', projection=True, name='latents'),
    classifier: LayerCreator = NetConf([128, 128],
                                       flatten_inputs=True,
                                       name='classifier'),
    encoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='encoder'),
    decoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='decoder'),
    n_resamples: int = 128,
    alpha: float = 0.05,
    temperature: float = 10.,
    name: str = 'ReparameterizedM3VAE',
    **kwargs,
):
    """Reparameterized M3 semi-supervised VAE.

    Builds a relaxed-one-hot labels posterior with a uniform categorical
    prior, a 'normal' denotations posterior, a prior regressor over classes,
    and a classifier network.

    Raises
    ------
    AssertionError : if `labels.posterior` is not 'relaxedonehot'.
    """
    super().__init__(latents=latents,
                     observation=observation,
                     encoder=encoder,
                     decoder=decoder,
                     name=name,
                     **kwargs)
    assert labels.posterior == 'relaxedonehot', \
        f"only support 'relaxedonehot' distribution for labels, given {labels.posterior}"
    self.marginalize = False
    self.n_classes = int(np.prod(labels.event_shape))
    self.n_resamples = int(n_resamples)
    self.regressor = PriorRegressor(self.n_classes)
    # rebuild the labels posterior with a uniform categorical prior and the
    # given Gumbel-softmax temperature
    self.labels = RVconf(
        self.n_classes,
        posterior='relaxedonehot',
        projection=True,
        prior=OneHotCategorical(probs=[1. / self.n_classes] * self.n_classes),
        name=labels.name,
        kwargs=dict(temperature=temperature)).create_posterior()
    self.denotations = RVconf(event_shape=(self.n_classes,),
                              posterior='normal',
                              projection=True,
                              name='denotations').create_posterior()
    self.classifier = _parse_layers(classifier)
def __init__(self,
             latents: RVconf = RVconf(5, 'mvndiag', projection=True,
                                      name='Latents'),
             factors: RVconf = RVconf(5, 'mvndiag', projection=True,
                                      name="Factors"),
             **kwargs):
    """Append `factors` as one extra latent variable and forward its total
    dimensionality to the parent as `latent_dim`."""
    assert isinstance(factors, RVconf), \
        "factors must be instance of RVmeta, but given: %s" % \
        str(type(factors))
    all_latents = tf.nest.flatten(latents) + [factors]
    factor_dim = int(np.prod(factors.event_shape))
    super().__init__(latents=all_latents, latent_dim=factor_dim, **kwargs)
    self.factors = factors
def __init__(
    self,
    encoder: List[keras.layers.Layer],
    decoder: List[keras.layers.Layer],
    layers_map: List[Tuple[str, str, int]] = (
        ('encoder2', 'decoder2', 16),
        ('encoder1', 'decoder3', 16),
        ('encoder0', 'decoder4', 16),
    ),
    beta: float = 10.,
    free_bits: float = 2.,
    name: str = 'PUnetVAE',
    **kwargs,
):
    """Probabilistic U-net VAE: creates (q, p) ladder-latent pairs for each
    encoder/decoder layer pair named in `layers_map`.

    Parameters
    ----------
    layers_map : triples (encoder_layer_name, decoder_layer_name, units);
        pairs whose names are not found in the prepared encoder/decoder are
        silently skipped.
    """
    encoder, decoder = _prepare_encoder_decoder(encoder, decoder)
    super().__init__(encoder=encoder,
                     decoder=decoder,
                     beta=beta,
                     name=name,
                     free_bits=free_bits,
                     **kwargs)
    # index layers by name to resolve the layers_map entries
    encoder_name = {i.name: i for i in self.encoder}
    decoder_name = {i.name: i for i in self.decoder}
    n_latents = 0
    ladder_latents = {}
    for i, j, units in layers_map:
        # only create ladder latents when both named layers exist
        if i in encoder_name and j in decoder_name:
            q = RVconf(units, 'mvndiag', projection=True,
                       name=f'ladder_q{n_latents}').create_posterior()
            p = RVconf(units, 'mvndiag', projection=True,
                       name=f'ladder_p{n_latents}').create_posterior()
            ladder_latents[i] = q
            ladder_latents[j] = p
            n_latents += 1
    self.ladder_latents = ladder_latents
    self.flatten = keras.layers.Flatten()
def __init__(
    self,
    mi_coef: Coefficient = 0.2,
    latents: RVconf = RVconf(32, 'mvndiag', projection=True, name='latents'),
    mutual_codes: RVconf = RVconf(10, 'mvndiag', projection=True,
                                  name='codes'),
    steps_without_mi: int = 100,
    beta: Coefficient = linear(vmin=1e-6, vmax=1., steps=2000),
    beta_codes: Coefficient = 0.,
    name: str = 'MutualInfoVAE',
    **kwargs,
):
    """VAE with a mutual-information term between latents and codes."""
    super().__init__(beta=beta, latents=latents, name=name, **kwargs)
    # record binary-ness before the config is turned into a concrete layer
    self.is_binary_code = mutual_codes.is_binary
    if isinstance(mutual_codes, RVconf):
        mutual_codes = mutual_codes.create_posterior()
    self.mutual_codes = mutual_codes
    self.steps_without_mi = int(steps_without_mi)
    self._beta_codes = beta_codes
    self._mi_coef = mi_coef
def __init__(
    self,
    ldd: LatentDirichletDecoder,
    encoder: LayerCreator = NetConf([300, 300, 300],
                                    flatten_inputs=True,
                                    name="Encoder"),
    decoder: LayerCreator = NetConf([300, 300, 300],
                                    flatten_inputs=True,
                                    name="Decoder"),
    latents: LayerCreator = RVconf(10, 'mvndiag', True, name='Latents'),
    warmup: Optional[int] = None,
    beta: float = 1.0,
    alpha: float = 1.0,
    **kwargs,
):
    """Constructor stub: signature only, the body is a placeholder (`...`).

    NOTE(review): the implementation appears to live elsewhere or is
    intentionally abstract — confirm before instantiating this class.
    """
    ...
def __init__(
    self,
    labels: RVconf = RVconf(10, 'onehot', projection=True, name="Labels"),
    alpha: float = 10.,
    ss_strategy: Literal['sum', 'logsumexp', 'mean', 'max',
                         'min'] = 'logsumexp',
    name: str = 'SemiFactorVAE',
    **kwargs,
):
    """Semi-supervised FactorVAE: forwards labels and the semi-supervised
    reduction strategy to the parent, then caches the label count and the
    supervised-loss weight `alpha`."""
    super().__init__(labels=labels,
                     ss_strategy=ss_strategy,
                     name=name,
                     **kwargs)
    # alpha scales the supervised loss; kept as a tensor in model dtype
    self.alpha = tf.convert_to_tensor(alpha, dtype=self.dtype, name='alpha')
    self.n_labels = self.discriminator.n_outputs
def __init__(
    self,
    labels: RVconf = RVconf(10, 'onehot', projection=True, name="digits"),
    alpha: float = 10.0,
    mi_coef: Union[float, Interpolation] = linear(vmin=0.1,
                                                  vmax=0.05,
                                                  steps=20000),
    reverse_mi: bool = False,
    steps_without_mi: int = 1000,
    **kwargs,
):
    """Semi-supervised VAE with an (optionally reversed, warm-started)
    mutual-information regularizer between latents and labels."""
    super().__init__(**kwargs)
    self.labels = _parse_layers(labels)
    self.alpha = alpha
    self._mi_coef = mi_coef
    self.reverse_mi = bool(reverse_mi)
    self.steps_without_mi = int(steps_without_mi)
    # discriminator/VAE updates share a single optimization step by default
    self._separated_steps = False
def __init__(
    self,
    batchnorm: bool = False,
    input_dropout: float = 0.,
    dropout: float = 0.,
    units: Sequence[int] = (1000, 1000, 1000, 1000, 1000),
    observation: Union[RVconf, Sequence[RVconf]] = RVconf(
        1, 'bernoulli', projection=True, name="discriminator"),
    activation: Union[str, Callable[[tf.Tensor],
                                    tf.Tensor]] = tf.nn.leaky_relu,
    ss_strategy: Literal['sum', 'logsumexp', 'mean', 'max',
                         'min'] = 'logsumexp',
    name: str = "FactorDiscriminator",
):
    """Dense discriminator over latent samples with one or more output
    random variables.

    Parameters
    ----------
    observation : output-variable configs; each event shape is flattened
        in-place (NOTE: mutates the caller's config objects).
    ss_strategy : reduction strategy over semi-supervised outputs.

    Raises
    ------
    AssertionError : empty `observation`, wrong config types, or an
        unknown `ss_strategy`.
    """
    if not isinstance(observation, (tuple, list)):
        observation = [observation]
    assert len(
        observation) > 0, "No output is given for FactorDiscriminator"
    assert all(
        isinstance(o, (RVconf, DistributionDense)) for o in observation), (
            f"outputs must be instance of RVmeta, but given:{observation}")
    n_outputs = 0
    for o in observation:
        if not o.projection:
            warnings.warn(f'Projection turn off for observation {o}!')
        # flatten each observation's event shape to a vector
        o.event_shape = (int(np.prod(o.event_shape)), )
        n_outputs += o.event_shape[0]
    layers = dense_network(units=units,
                           batchnorm=batchnorm,
                           dropout=dropout,
                           flatten_inputs=True,
                           input_dropout=input_dropout,
                           activation=activation,
                           prefix=name)
    super().__init__(layers, name=name)
    self.ss_strategy = str(ss_strategy)
    self.observation = observation
    self.n_outputs = n_outputs
    self._distributions = []
    assert self.ss_strategy in {'sum', 'logsumexp', 'mean', 'max', 'min'}
def __init__(
    self,
    labels: RVconf = RVconf(10, 'relaxedonehot', name='digits'),
    classifier: Sequence[int] = (1024, 1024, 1024, 1024),
    activation: Union[str, Callable[[Any], tf.Tensor]] = 'relu',
    alpha: float = 10.,
    **kwargs,
):
    """Semi-supervised VAE with a classifier head and x/y/z projection
    layers for q(z|x,y) and p(x|z,y).

    Parameters
    ----------
    classifier : hidden units of the classifier MLP.
    alpha : supervised-loss weight.

    Raises
    ------
    NotImplementedError : if `labels.event_shape` does not match one of the
        known datasets (MNIST-like, dSprites, Shapes3D).
    """
    super().__init__(**kwargs)
    self.alpha = float(alpha)
    ## the networks
    # TODO: force reparams here
    self.labels = _parse_layers(labels)
    self.n_classes = int(np.prod(labels.event_shape))
    self.classifier = NetConf(classifier,
                              flatten_inputs=True,
                              activation=activation,
                              name='Classifier').create_network()
    self.xy_to_qz_net = NetConf([128, 128],
                                activation=activation,
                                name='xy_to_qz').create_network()
    self.zy_to_px_net = NetConf([128, 128],
                                activation=activation,
                                name='zy_to_px').create_network()
    # q(z|xy)
    self.y_to_qz = Dense(128, activation='linear', name='y_to_qz')
    self.x_to_qz = Dense(128, activation='linear', name='x_to_qz')
    # p(x|zy)
    # BUG FIX: both layers below were named 'z_to_px' (copy-paste); give
    # y_to_px its own unique name so the layers are distinguishable.
    self.y_to_px = Dense(128, activation='linear', name='y_to_px')
    self.z_to_px = Dense(128, activation='linear', name='z_to_px')
    self.concat = Concatenate(axis=-1)
    self.flatten = Flatten()
    self.onehot_dist = DistributionLambda(lambda p: VectorDeterministic(p))
    # classes: per-factor group sizes for known datasets
    if self.n_classes in (10, 3, 4):
        self.classes = [self.n_classes]
    elif self.n_classes == (15 + 8 + 4 + 30):  # dsprites
        self.classes = [15, 8, 4, 10, 10, 10]
    elif self.n_classes == (40 + 6 + 3 + 32 + 32):  # shapes3d
        self.classes = [40, 6, 3, 32, 32]
    else:
        raise NotImplementedError
def __init__(
    self,
    ldd: LatentDirichletDecoder,
    encoder: LayerCreator = NetConf(flatten_inputs=True, name="Encoder"),
    decoder: LayerCreator = NetConf(flatten_inputs=True, name="Decoder"),
    latents: LayerCreator = RVconf(10,
                                   posterior='mvndiag',
                                   projection=True,
                                   name="Latents"),
    **kwargs,
):
    """LDD variant with a frozen inference side: the encoder and all latent
    layers are made non-trainable (neither the KL divergence nor the encoder
    is trained)."""
    super().__init__(ldd=ldd,
                     encoder=encoder,
                     decoder=decoder,
                     latents=latents,
                     **kwargs)
    self.encoder.trainable = False
    for latent_layer in self.latent_layers:
        latent_layer.trainable = False
def __init__(self,
             discriminator_units: Sequence[int] = (1000, 1000, 1000, 1000,
                                                   1000),
             discriminator_optim: Optional[tf.optimizers.Optimizer] = None,
             activation: Union[str, Callable[[], Any]] = tf.nn.relu,
             batchnorm: bool = False,
             tc_coef: float = 7.0,
             maximize_tc: bool = False,
             name: str = 'FactorVAE',
             **kwargs):
    """FactorVAE: VAE plus a total-correlation discriminator trained with
    its own optimizer.

    Parameters
    ----------
    tc_coef : weight of the total-correlation term.
    discriminator_optim : optimizer for the discriminator; defaults to
        Adam(lr=1e-5, beta_1=0.5, beta_2=0.9).

    Notes
    -----
    'ss_strategy' and 'labels' may be passed via kwargs; they are popped
    here before the parent constructor sees them.
    """
    ss_strategy = kwargs.pop('ss_strategy', 'logsumexp')
    labels = kwargs.pop(
        'labels', RVconf(1, 'bernoulli', projection=True,
                         name="discriminator"))
    super().__init__(name=name, **kwargs)
    self.tc_coef = tf.convert_to_tensor(tc_coef,
                                        dtype=self.dtype,
                                        name='tc_coef')
    ## init discriminator
    self.discriminator = FactorDiscriminator(
        units=as_tuple(discriminator_units),
        activation=activation,
        batchnorm=batchnorm,
        ss_strategy=ss_strategy,
        observation=labels)
    if discriminator_optim is None:
        discriminator_optim = tf.optimizers.Adam(learning_rate=1e-5,
                                                 beta_1=0.5,
                                                 beta_2=0.9)
    self.discriminator_optim = discriminator_optim
    ## Discriminator and VAE must be trained separately
    self.disc_params = []
    self.vae_params = []
    self.maximize_tc = bool(maximize_tc)
    ## For training
    # store class for training factor discriminator, this allow later
    # modification without re-writing the train_steps method
    self._is_pretraining = False
def __init__(
    self,
    labels: RVconf = RVconf(10, 'onehot', name='digits'),
    observation: RVconf = RVconf((28, 28, 1),
                                 'bernoulli',
                                 projection=True,
                                 name='image'),
    latents: RVconf = RVconf(64, 'mvndiag', projection=True, name='latents'),
    classifier: LayerCreator = NetConf([128, 128],
                                       flatten_inputs=True,
                                       name='classifier'),
    encoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='encoder'),
    decoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='decoder'),
    xy_to_qz: LayerCreator = NetConf([128, 128], name='xy_to_qz'),
    zy_to_px: LayerCreator = NetConf([128, 128], name='zy_to_px'),
    embedding_dim: int = 128,
    embedding_method: Literal['repetition', 'projection', 'dictionary',
                              'sequential', 'identity'] = 'sequential',
    batchnorm: bool = False,
    dropout: float = 0.,
    alpha: float = 0.05,
    beta: float = 1.,
    temperature: float = 10.,
    marginalize: bool = True,
    name: str = 'ConditionalM2VAE',
    **kwargs,
):
    """Conditional M2 semi-supervised VAE.

    Builds classifier / xy->qz / zy->px networks, a (relaxed) one-hot labels
    posterior with a uniform categorical prior, class embedders for the
    q(z|x,y) and p(x|z,y) paths, and optional batch-norm/dropout.

    Parameters
    ----------
    marginalize : when True the labels are marginalized exactly, which
        forces temperature to 0 and a plain 'onehot' posterior.
    temperature : Gumbel-softmax temperature; 0 means non-relaxed one-hot.
    """
    super().__init__(latents=latents,
                     observation=observation,
                     encoder=encoder,
                     decoder=decoder,
                     beta=beta,
                     name=name,
                     **kwargs)
    self.alpha = tf.convert_to_tensor(alpha, dtype=self.dtype, name="alpha")
    self.embedding_dim = int(embedding_dim)
    self.embedding_method = str(embedding_method)
    self.batchnorm = bool(batchnorm)
    self.dropout = float(dropout)
    ## the networks
    self.classifier = _parse_layers(classifier)
    self.xy_to_qz_net = _parse_layers(xy_to_qz)
    self.zy_to_px_net = _parse_layers(zy_to_px)
    ## check the labels distribution
    # NOTE(review): if `labels` has neither attribute, `posterior_name`
    # is unbound and the membership test below raises NameError — confirm
    # all callers pass RVconf or DistributionDense.
    if hasattr(labels, 'posterior'):
        posterior_name = str(labels.posterior)
    if hasattr(labels, 'posterior_layer'):
        posterior_name = str(labels.posterior_layer).lower()
    if 'onehot' not in posterior_name:
        warnings.warn(
            'Conditional VAE only support one-hot or relaxed one-hot distribution, '
            f'but given: {labels}')
    self.n_classes = int(np.prod(labels.event_shape))
    self.marginalize = bool(marginalize)
    # labels distribution
    if marginalize:
        temperature = 0
    if temperature == 0.:
        posterior = 'onehot'
        dist_kw = dict()
        self.relaxed = False
    else:
        posterior = 'relaxedonehot'
        dist_kw = dict(temperature=temperature)
        self.relaxed = True
    self.labels = RVconf(self.n_classes,
                         posterior,
                         projection=True,
                         prior=OneHotCategorical(probs=[1. / self.n_classes] *
                                                 self.n_classes),
                         name=labels.name,
                         kwargs=dist_kw).create_posterior()
    # create embedder
    embedder = get_embedding(self.embedding_method)
    # q(z|xy)
    self.y_to_qz = embedder(n_classes=self.n_classes,
                            event_shape=self.embedding_dim,
                            name='y_to_qz')
    self.x_to_qz = Dense(embedding_dim, activation='linear', name='x_to_qz')
    # p(x|zy)
    self.y_to_px = embedder(n_classes=self.n_classes,
                            event_shape=self.embedding_dim,
                            name='y_to_px')
    self.z_to_px = Dense(embedding_dim, activation='linear', name='z_to_px')
    # batch normalization
    if self.batchnorm:
        self.qz_xy_norm = BatchNormalization(axis=-1, name='qz_xy_norm')
        self.px_zy_norm = BatchNormalization(axis=-1, name='px_zy_norm')
    if 0.0 < self.dropout < 1.0:
        self.qz_xy_drop = Dropout(rate=self.dropout, name='qz_xy_drop')
        self.px_zy_drop = Dropout(rate=self.dropout, name='px_zy_drop')
def mnist_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Network for MNIST dataset image size (28, 28, 1)

    Returns a dict with keys 'encoder', 'decoder', 'observation', 'latents',
    plus 'labels' when `is_semi_supervised`.
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi import BiConvLatents
    n_channels = int(kwargs.get('n_channels', 1))
    proj_dim = 196
    input_shape = (28, 28, n_channels)
    if zdim is None:
        zdim = 32
    conv, deconv = _prepare_cnn(activation=activation)
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'bernoulli'))
    encoder = SequentialNetwork(
        [
            keras.layers.InputLayer(input_shape),
            CenterAt0(enable=centerize_image),
            conv(32, 5, strides=1, name='encoder0'),  # 28, 28, 32
            conv(32, 5, strides=2, name='encoder1'),  # 14, 14, 32
            conv(64, 5, strides=1, name='encoder2'),  # 14, 14, 64
            conv(64, 5, strides=2, name='encoder3'),  # 7 , 7 , 64
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear',
                           name='decoder_proj'),
        keras.layers.Reshape((7, 7, proj_dim // 49)),  # 7, 7, 4
        deconv(64, 5, strides=2, name='decoder2'),  # 14, 14, 64
        # hierarchical skip latent tying decoder3 to an encoder layer.
        # NOTE(review): verify encoder.layers[3] indexes the intended layer
        # (depends on whether InputLayer is counted in .layers).
        BiConvLatents(
            conv(64, 5, strides=1, name='decoder3'),  # 14, 14, 64
            encoder=encoder.layers[3],
            filters=16,
            kernel_size=14,
            strides=7,
            disable=True,
            name='latents2'),
        deconv(32, 5, strides=2, name='decoder4'),  # 28, 28, 32
        conv(32, 5, strides=1, name='decoder5'),  # 28, 28, 32
        conv(n_channels * n_params,
             1,
             strides=1,
             activation='linear',
             name='decoder6'),
        last_layer
    ]
    # for non-hierarchical models, unwrap BiConvLatents to its plain layer
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz,
                     projection=True,
                     name="latents").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(
            10,
            'onehot',
            projection=True,
            name=kwargs.get('labels_name', 'digits'),
        ).create_posterior()
    return networks
def __init__(self,
             latents=RVconf(5, 'mvndiag', projection=True, name='Latents'),
             factors=RVconf(5, 'mvndiag', projection=True, name='Factors'),
             **kwargs):
    """Thin wrapper: forwards the latents/factors configs to the parent."""
    super().__init__(latents=latents, factors=factors, **kwargs)
def __init__(
    self,
    observation: LayerCreator = RVconf((28, 28, 1),
                                       'bernoulli',
                                       projection=True,
                                       name='image'),
    latents: Optional[LayerCreator] = RVconf(16,
                                             'mvndiag',
                                             projection=True,
                                             name="latents"),
    encoder: Optional[LayerCreator] = None,
    decoder: Optional[LayerCreator] = None,
    **kwargs,
):
    """Base VAE constructor: builds the encoder / latents / decoder /
    observation sub-networks; each may be a single config or a sequence
    (multi-branch models).
    """
    if encoder is None:
        encoder = NetConf((512, 512), flatten_inputs=True, name="encoder")
    if decoder is None:
        decoder = NetConf((512, 512),
                          flatten_inputs=True,
                          flatten_outputs=True,
                          name="decoder")
    ### keras want this supports_masking on to enable support masking
    super().__init__(**kwargs)
    ### create layers
    # encoder
    if isinstance(encoder, (tuple, list)):
        self._encoder = [
            _parse_layers(network=e, name=f"encoder{i}")
            for i, e in enumerate(encoder)
        ]
        self._encoder_args = [_get_args(e) for e in self._encoder]
    else:
        self._encoder = _parse_layers(network=encoder, name="encoder")
        self._encoder_args = _get_args(self.encoder)
    # latents
    if isinstance(latents, (tuple, list)):
        self._latents = [
            _parse_layers(network=z, name=f"latents{i}")
            for i, z in enumerate(latents)
        ]
        self._latents_args = [_get_args(z) for z in self.latents]
    else:
        self._latents = _parse_layers(network=latents, name="latents")
        self._latents_args = _get_args(self.latents)
    # decoder
    if isinstance(decoder, (tuple, list)):
        self._decoder = [
            _parse_layers(network=d, name=f"decoder{i}")
            for i, d in enumerate(decoder)
        ]
        self._decoder_args = [_get_args(d) for d in self.decoder]
    else:
        self._decoder = _parse_layers(network=decoder, name="decoder")
        self._decoder_args = _get_args(self.decoder)
    # observation
    if isinstance(observation, (tuple, list)):
        # BUG FIX: each element `o` must be parsed individually; previously
        # the whole `observation` sequence was passed to _parse_layers on
        # every iteration instead of the element.
        self._observation = [
            _parse_layers(network=o, name=f"observation{i}")
            for i, o in enumerate(observation)
        ]
        self._observation_args = [_get_args(o) for o in self.observation]
    else:
        self._observation = _parse_layers(network=observation,
                                          name="observation")
        self._observation_args = _get_args(self.observation)
def dsprites_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Networks for the dSprites dataset, image size (64, 64, n_channels).

    Returns a dict with keys 'encoder', 'decoder', 'observation', 'latents',
    plus 'labels' when `is_semi_supervised`.
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi.autoencoder import BiConvLatents
    if zdim is None:
        zdim = 10
    n_channels = int(kwargs.get('n_channels', 1))
    input_shape = (64, 64, n_channels)
    conv, deconv = _prepare_cnn(activation=activation)
    proj_dim = kwargs.get('proj_dim', None)
    if proj_dim is None:
        proj_dim = 128 if n_channels == 1 else 256
    else:
        proj_dim = int(proj_dim)
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'bernoulli'))
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=2, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    # (removed a large commented-out alternative two-ladder decoder layout;
    # recover it from version control if that configuration is needed again)
    layers = [
        keras.layers.Dense(proj_dim, activation='linear',
                           name='decoder_proj'),
        keras.layers.Reshape((4, 4, proj_dim // 16)),
        BiConvLatents(deconv(64, 4, strides=2, name='decoder1'),
                      encoder=encoder.layers[3],
                      filters=32,
                      kernel_size=8,
                      strides=4,
                      disable=True,
                      name='latents2'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        # NOTE: this last projection layer with linear activation is crucial
        # otherwise the distribution parameterized by this layer won't converge
        conv(n_channels * n_params,
             1,
             strides=1,
             activation='linear',
             name='decoder6'),
        last_layer
    ]
    # for non-hierarchical models, unwrap BiConvLatents to its plain layer
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz,
                     projection=True,
                     name="latents").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers.dense_distribution import DistributionDense
        # TODO: update
        networks['labels'] = DistributionDense(
            event_shape=(5, ),
            posterior=_dsprites_distribution,
            units=9,
            name='geometry2d')
    return networks
def celeba_networks(qz: str = 'mvndiag',
                    zdim: Optional[int] = None,
                    activation: Union[Callable, str] = tf.nn.elu,
                    is_semi_supervised: bool = False,
                    is_hierarchical: bool = False,
                    centerize_image: bool = True,
                    skip_generator: bool = False,
                    n_labels: int = 18,
                    **kwargs):
    """Networks for the CelebA dataset, image size (64, 64, 3).

    Returns a dict with keys 'encoder', 'decoder', 'observation', 'latents',
    plus an `n_labels`-attribute 'labels' head when `is_semi_supervised`.
    """
    from odin.bay.random_variable import RVconf
    if zdim is None:
        zdim = 45
    input_shape = (64, 64, 3)
    n_components = 10  # for Mixture Quantized Logistic
    n_channels = input_shape[-1]
    conv, deconv = _prepare_cnn(activation=activation)
    proj_dim = 512
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=2, name='encoder0'),
            conv(32, 4, strides=2, name='encoder1'),
            conv(64, 4, strides=2, name='encoder2'),
            conv(64, 4, strides=1, name='encoder3'),
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear',
                           name='decoder_proj'),
        keras.layers.Reshape((8, 8, proj_dim // 64)),
        deconv(64, 4, strides=1, name='decoder1'),
        deconv(64, 4, strides=2, name='decoder2'),
        deconv(32, 4, strides=2, name='decoder3'),
        deconv(32, 4, strides=2, name='decoder4'),
        conv(
            2 * n_channels,
            # MixtureQuantizedLogistic.params_size(n_components, n_channels),
            1,
            strides=1,
            activation='linear',
            name='decoder5'),
    ]
    from odin.bay import BiConvLatents
    # for non-hierarchical models, unwrap BiConvLatents to its plain layer
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    latents = RVconf((zdim, ), qz,
                     projection=True,
                     name="latents").create_posterior()
    # NOTE(review): other *_networks unpack _parse_distribution into a
    # 3-tuple; here the raw return value is used directly — confirm the
    # 'qlogistic' branch returns the observation layer itself.
    observation = _parse_distribution(input_shape, 'qlogistic')
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        from odin.bay.layers import DistributionDense
        networks['labels'] = DistributionDense(event_shape=n_labels,
                                               posterior=_celeba_distribution,
                                               units=n_labels,
                                               name='attributes')
    return networks
def pbmc_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = 32,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    log_norm: bool = True,
    cnn: bool = True,
    units: Sequence[int] = (512, 512, 512),
    **kwargs,
) -> Dict[str, Layer]:
    """Network for Cortex mRNA sequencing datasets

    NOTE(review): the function name says PBMC but the docstring said
    Cortex — confirm the intended dataset. Input is a 2019-dim expression
    vector with a ZINB observation; 32 'nb' ADT labels when semi-supervised.
    """
    from odin.bay.random_variable import RVconf
    input_shape = (2019, )
    n_labels = 32
    if zdim is None:
        zdim = 32
    ## dense network
    if not cnn:
        encoder = SequentialNetwork(
            [LogNorm(enable=log_norm)] + [
                keras.layers.Dense(
                    u, activation=activation, name=f'encoder{i}')
                for i, u in enumerate(units)
            ],
            name='encoder',
        )
        decoder = SequentialNetwork(
            [
                keras.layers.Dense(
                    u, activation=activation, name=f'decoder{i}')
                for i, u in enumerate(units)
            ],
            name='decoder',
        )
    ## conv network
    else:
        Conv1D = partial(keras.layers.Conv1D,
                         strides=2,
                         padding='same',
                         activation=activation)
        Conv1DTranspose = partial(keras.layers.Conv1DTranspose,
                                  strides=2,
                                  padding='same',
                                  activation=activation)
        encoder = SequentialNetwork(
            [
                LogNorm(enable=log_norm),
                keras.layers.Lambda(
                    lambda x: tf.expand_dims(x, axis=-1)),  # (n, 2019, 1)
                Conv1D(32, 7, name='encoder0'),
                Conv1D(64, 5, name='encoder1'),
                Conv1D(128, 3, name='encoder2'),
                Conv1D(128, 3, name='encoder3'),
                keras.layers.Flatten()
            ],
            name='encoder',
        )
        decoder = SequentialNetwork(
            [
                keras.layers.Dense(256,
                                   activation=activation,
                                   name='decoder0'),
                keras.layers.Lambda(
                    lambda x: tf.expand_dims(x, axis=-1)),  # (n, 256, 1)
                Conv1DTranspose(128, 3, strides=1, name='decoder1'),
                Conv1DTranspose(128, 3, name='decoder2'),
                Conv1DTranspose(64, 5, name='decoder3'),
                Conv1DTranspose(32, 7, name='decoder4'),
                Conv1DTranspose(1, 1, strides=1, name='decoder5'),
                keras.layers.Flatten()
            ],
            name='decoder',
        )
    latents = RVconf((zdim, ), qz,
                     projection=True,
                     name="latents").create_posterior()
    # zero-inflated negative binomial likelihood for count data
    observation = RVconf(input_shape, "zinb", projection=True,
                         name="mrna").create_posterior()
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(n_labels,
                                    'nb',
                                    projection=True,
                                    name='adt').create_posterior()
    return networks
def __init__(
    self,
    n_classes: int = 10,
    observation=RVconf((28, 28, 1), 'bernoulli', projection=True,
                       name='image'),
    latents: RVconf = RVconf(64, 'mvndiag', projection=True, name='latents'),
    classifier: LayerCreator = NetConf([128, 128],
                                       flatten_inputs=True,
                                       name='classifier'),
    auxiliary: RVconf = RVconf(64, 'mvndiag', projection=True,
                               name='auxiliary'),
    encoder_a: LayerCreator = NetConf([512, 512],
                                      flatten_inputs=True,
                                      name='encoder_a'),
    decoder_a: LayerCreator = NetConf([512, 512],
                                      flatten_inputs=True,
                                      name='decoder_a'),
    encoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='encoder'),
    decoder: LayerCreator = NetConf([512, 512],
                                    flatten_inputs=True,
                                    name='decoder'),
    axy_to_qz: LayerCreator = NetConf([128, 128], name='axy_to_qz'),
    azy_to_px: LayerCreator = NetConf([128, 128], name='azy_to_px'),
    embedding_dim: int = 128,
    embedding_method: Literal['repetition', 'projection', 'dictionary',
                              'sequential', 'identity'] = 'sequential',
    batchnorm: bool = False,
    dropout: float = 0.,
    skip_connection: bool = True,
    alpha: float = 1.0,
    beta: float = 1.0,
    temperature: float = 10.,
    marginalize: bool = True,
    name='AuxiliaryVAE',
    **kwargs,
):
    """Auxiliary semi-supervised VAE: adds an auxiliary variable `a` with
    its own encoder/decoder and q(a|x) / p(a|x,z) posteriors on top of the
    conditional M2 parent, plus linear projections wiring a, x, y, z.
    """
    super().__init__(n_classes=n_classes,
                     observation=observation,
                     latents=latents,
                     classifier=classifier,
                     encoder=encoder,
                     decoder=decoder,
                     xy_to_qz=axy_to_qz,
                     zy_to_px=azy_to_px,
                     embedding_dim=embedding_dim,
                     embedding_method=embedding_method,
                     batchnorm=batchnorm,
                     dropout=dropout,
                     alpha=alpha,
                     beta=beta,
                     temperature=temperature,
                     marginalize=marginalize,
                     name=name,
                     **kwargs)
    self.skip_connection = bool(skip_connection)
    self.batchnorm = bool(batchnorm)
    self.qa_dist = auxiliary.create_posterior(name='qa_x')
    self.pa_dist = auxiliary.create_posterior(name='pa_xz')
    self.encoder_a = _parse_layers(encoder_a)
    self.decoder_a = _parse_layers(decoder_a)
    # labels connections
    self.x_to_qy = Dense(units=self.embedding_dim, activation='linear')
    self.a_to_qy = Dense(units=self.embedding_dim, activation='linear')
    # auxiliary connections
    self.a_to_qz = Dense(units=self.embedding_dim, activation='linear')
    self.a_to_px = Dense(units=self.embedding_dim, activation='linear')
    # for p(a|yz)
    self.y_to_pa = Dense(units=self.embedding_dim, activation='linear')
    self.z_to_pa = Dense(units=self.embedding_dim, activation='linear')
    # batchnorm and dropout
    if self.batchnorm:
        self.qy_ax_norm = BatchNormalization(axis=-1, name='qy_ax_norm')
        self.pa_zy_norm = BatchNormalization(axis=-1, name='pa_zy_norm')
    if 0.0 < self.dropout < 1.0:
        self.qy_ax_drop = Dropout(rate=self.dropout, name='qy_ax_drop')
        # BUG FIX: was `Dropout(axis=-1, ...)` — keras Dropout takes no
        # `axis` argument and requires `rate`, so this raised TypeError
        # whenever 0 < dropout < 1.
        self.pa_zy_drop = Dropout(rate=self.dropout, name='pa_zy_drop')
def cifar_networks(
    qz: str = 'mvndiag',
    zdim: Optional[int] = None,
    activation: Callable[[tf.Tensor], tf.Tensor] = tf.nn.elu,
    is_semi_supervised: bool = False,
    is_hierarchical: bool = False,
    centerize_image: bool = True,
    skip_generator: bool = False,
    **kwargs,
) -> Dict[str, Layer]:
    """Network for CIFAR dataset image size (32, 32, 3)

    Returns a dict with keys 'encoder', 'decoder', 'observation', 'latents',
    plus an `n_classes` 'onehot' labels head when `is_semi_supervised`.
    """
    from odin.bay.random_variable import RVconf
    from odin.bay.vi.autoencoder.hierarchical_vae import BiConvLatents
    if zdim is None:
        zdim = 256
    n_channels = kwargs.get('n_channels', 3)
    input_shape = (32, 32, n_channels)
    conv, deconv = _prepare_cnn(activation=activation)
    n_classes = kwargs.get('n_classes', 10)
    proj_dim = 8 * 8 * 8
    ## output distribution
    n_params, observation, last_layer = _parse_distribution(
        input_shape, kwargs.get('distribution', 'qlogistic'))
    ## encoder
    encoder = SequentialNetwork(
        [
            CenterAt0(enable=centerize_image),
            conv(32, 4, strides=1, name='encoder0'),  # 32, 32, 32
            conv(32, 4, strides=2, name='encoder1'),  # 16, 16, 32
            conv(64, 4, strides=1, name='encoder2'),  # 16, 16, 64
            conv(64, 4, strides=2, name='encoder3'),  # 8, 8, 64
            keras.layers.Flatten(),
            keras.layers.Dense(
                proj_dim, activation='linear', name='encoder_proj')
        ],
        name='Encoder',
    )
    layers = [
        keras.layers.Dense(proj_dim, activation='linear',
                           name='decoder_proj'),
        keras.layers.Reshape((8, 8, proj_dim // 64)),  # 8, 8, 4
        deconv(64, 4, strides=2, name='decoder1'),  # 16, 16, 64
        # hierarchical skip latents pairing decoder layers with encoder
        # layers; inert (plain conv) unless is_hierarchical
        BiConvLatents(
            conv(64, 4, strides=1, name='decoder2'),  # 16, 16, 64
            encoder=encoder.layers[3],
            filters=32,
            kernel_size=8,
            strides=4,
            disable=True,
            name='latents1'),
        deconv(32, 4, strides=2, name='decoder3'),  # 32, 32, 32
        BiConvLatents(
            conv(32, 4, strides=1, name='decoder4'),  # 32, 32, 32
            encoder=encoder.layers[1],
            filters=16,
            kernel_size=8,
            strides=4,
            disable=True,
            name='latents2'),
        conv(
            n_channels * n_params,  # 32, 32, 3
            1,
            strides=1,
            activation='linear',
            name='decoder5'),
        last_layer
    ]
    # for non-hierarchical models, unwrap BiConvLatents to its plain layer
    layers = [
        i.layer if isinstance(i, BiConvLatents) and not is_hierarchical else i
        for i in layers
    ]
    if skip_generator:
        decoder = SkipSequential(layers=layers, name='SkipDecoder')
    else:
        decoder = SequentialNetwork(layers=layers, name='Decoder')
    ## others
    latents = RVconf((zdim, ), qz,
                     projection=True,
                     name="latents").create_posterior()
    # create the observation of MixtureQuantizedLogistic
    networks = dict(encoder=encoder,
                    decoder=decoder,
                    observation=observation,
                    latents=latents)
    if is_semi_supervised:
        networks['labels'] = RVconf(n_classes,
                                    'onehot',
                                    projection=True,
                                    name='labels').create_posterior()
    return networks