def _build_bottleneck(model_type, latent_dim): if model_type == 'vanilla' or model_type == 'stacked' or model_type == 'denoising' or model_type == 'shallow': bottleneck = bottlenecks.IdentityBottleneck(latent_dim) elif model_type == 'vae': bottleneck = bottlenecks.VariationalBottleneck(latent_dim) elif model_type == 'beta_vae_strict': bottleneck = bottlenecks.VariationalBottleneck(latent_dim, beta=2.) elif model_type == 'beta_vae_loose': bottleneck = bottlenecks.VariationalBottleneck(latent_dim, beta=0.5) elif model_type == 'sparse': bottleneck = bottlenecks.SparseBottleneck(latent_dim, sparsity=0.25) elif model_type == 'vq': bottleneck = bottlenecks.VectorQuantizedBottleneck(latent_dim, num_categories=512) else: raise ValueError(f'Unknown model type {model_type}.') return bottleneck
def setUp(self):
    """Prepare a dense-encoder / variational-bottleneck Classifier fixture.

    Sets ``self.net`` to a ``Classifier`` with 10 outputs.
    Sets ``self.test_inputs`` to a random batch of 16 single-channel
    32x32 tensors.
    Sets ``self.output_shape`` to the expected ``(16, 10)`` result size.
    """
    batch_size = 16
    num_classes = 10

    # NOTE: the network is built before torch.randn on purpose.
    # Module construction may draw from the global torch RNG (assumption;
    # verify), so reordering could change the sampled test inputs.
    # Arguments are evaluated left to right: encoder, then bottleneck.
    self.net = Classifier(encoders.DenseEncoder((1, 32, 32), 3, 64),
                          bottlenecks.VariationalBottleneck(32),
                          num_classes)

    self.test_inputs = torch.randn(batch_size, 1, 32, 32)
    self.output_shape = torch.Size((batch_size, num_classes))
def setUp(self):
    """Create the variational bottleneck under test and store it as ``self.neck``."""
    # Passed as the first positional (latent dimension) argument.
    latent_dim = 1
    self.neck = bottlenecks.VariationalBottleneck(latent_dim)