def __init__(self, args: Arguments, free_bits=None, beta=1, **kwargs):
  """Standard-normal-prior VAE: swaps the default latent layer for a
  DistributionDense carrying an explicit N(0, I) prior."""
  n_latents = args.zdim
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  std_normal = Normal(loc=tf.zeros([n_latents]), scale=tf.ones([n_latents]))
  # reuse the original layer name so saved weights stay compatible
  nets['latents'] = DistributionDense(units=n_latents * 2,
                                      posterior=make_normal,
                                      prior=std_normal,
                                      name=nets['latents'].name)
  super().__init__(free_bits=free_bits, beta=beta, **nets, **kwargs)
def __init__(self, args: Arguments, **kwargs):
  """Gamma-prior VAE: replaces the default latent layer with a
  DistributionDense whose prior is Gamma(0.3, 0.3) per dimension."""
  n_latents = args.zdim
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  gamma_prior = Gamma(rate=tf.fill([n_latents], 0.3),
                      concentration=tf.fill([n_latents], 0.3))
  # reuse the original layer name so saved weights stay compatible
  nets['latents'] = DistributionDense(units=n_latents * 2,
                                      posterior=make_gamma,
                                      prior=gamma_prior,
                                      name=nets['latents'].name)
  super().__init__(**nets, **kwargs)
def model_fullcov(args: Arguments):
  """VAE whose posterior is a full-covariance multivariate normal
  (lower-triangular parameterization) over a flat latent vector."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  n_units = int(np.prod(nets['latents'].event_shape))
  standard_normal = Independent(
      Normal(tf.zeros([n_units]), tf.ones([n_units])), 1)
  nets['latents'] = RVconf(event_shape=n_units,
                           projection=True,
                           posterior='mvntril',
                           prior=standard_normal,
                           name='latents').create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')
def __init__(self, args: Arguments, free_bits=None, beta=1., **kwargs):
  """VAE with a Gaussian observation distribution: the output layer
  predicts mean and scale, hence units = 2 * prod(event_shape)."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  default_obs: DistributionDense = nets['observation']
  shape = default_obs.event_shape
  nets['observation'] = DistributionDense(
      event_shape=shape,
      units=int(np.prod(shape)) * 2,
      posterior=partial(make_gaussian_out, event_shape=shape),
      name='image')
  super().__init__(free_bits=free_bits, beta=beta, **nets, **kwargs)
def __init__(self, args: Arguments, **kwargs):
  """VAE with two parallel normal latent layers sharing an N(0, I) prior."""
  n_latents = args.zdim
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  std_normal = Normal(loc=tf.zeros([n_latents]), scale=tf.ones([n_latents]))
  nets['latents'] = [
      DistributionDense(units=n_latents * 2,
                        posterior=make_normal,
                        prior=std_normal,
                        name=f'latents{i}')
      for i in (1, 2)
  ]
  super().__init__(**nets, **kwargs)
def create_model(args: Namespace) -> VariationalModel:
  """Instantiate and build the VariationalModel whose class name matches
  `args.vae` (case-insensitive lookup among this module's globals).

  Raises:
    ValueError: if no matching model class is found.
  """
  networks = get_networks(args.ds, zdim=args.zdim, is_semi_supervised=True)
  name = args.vae
  for key, candidate in globals().items():
    if not (isinstance(candidate, type)
            and issubclass(candidate, VariationalModel)):
      continue
    if name.lower() != key.lower():
      continue
    # keep only the networks the constructor actually accepts
    spec = inspect.getfullargspec(candidate.__init__)
    accepted = {
        k: v for k, v in networks.items()
        if k in spec.args or k in spec.kwonlyargs
    }
    model = candidate(**accepted)
    model.build([None] + IMAGE_SHAPE)
    return model
  raise ValueError(f'Cannot find model with name: {name}')
def model_gmmprior(args: Arguments):
  """VAE whose latent prior is a learned 100-component Gaussian mixture.

  The mixture parameters (component means, diagonal scales, mixing logits)
  are created as TF1-style variables so they train with the model.
  """
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  # FIX: cast to a plain int — np.prod returns a numpy scalar; the sibling
  # builders (model_fullcov, model_fullcovgmm) already apply this cast.
  latent_size = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  loc = tf.compat.v1.get_variable(name="loc",
                                  shape=[n_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
      name="raw_scale_diag", shape=[n_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
      name="mixture_logits", shape=[n_components])
  nets['latents'].prior = MixtureSameFamily(
      components_distribution=MultivariateNormalDiag(
          loc=loc,
          # softplus keeps the scales positive; exp(-7) is a numerical floor
          scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
      mixture_distribution=Categorical(logits=mixture_logits),
      name="prior")
  return VariationalAutoencoder(**nets, name='GMMPrior')
def model_fullcovgmm(args: Arguments):
  """Full-covariance ('mvntril') posterior combined with a learned
  100-component Gaussian-mixture prior."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  latent_size = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  loc = tf.compat.v1.get_variable(name="loc",
                                  shape=[n_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
      name="raw_scale_diag", shape=[n_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
      name="mixture_logits", shape=[n_components])
  nets['latents'] = RVconf(
      event_shape=latent_size,
      projection=True,
      posterior='mvntril',
      prior=MixtureSameFamily(
          components_distribution=MultivariateNormalDiag(
              loc=loc,
              # softplus keeps scales positive; exp(-7) is a numerical floor
              scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
          mixture_distribution=Categorical(logits=mixture_logits),
          name="prior"),
      name='latents').create_posterior()
  # FIX: previously name='FullCov', colliding with model_fullcov — give the
  # GMM-prior variant its own distinct model name.
  return VariationalAutoencoder(**nets, name='FullCovGMM')
def get_dense_networks(args: Arguments):
  """MNIST networks with the default encoder/decoder replaced by
  three-layer 1024-unit relu MLPs."""
  nets = get_networks('mnist',
                      is_semi_supervised=False,
                      is_hierarchical=False,
                      zdim=args.zdim)
  enc_layers = [InputLayer(input_shape=[28, 28, 1]), CenterAt0(), Flatten()]
  enc_layers += [Dense(1024, activation='relu') for _ in range(3)]
  nets['encoder'] = SequentialNetwork(enc_layers, name='Encoder')
  dec_layers = [InputLayer(input_shape=[args.zdim])]
  dec_layers += [Dense(1024, activation='relu') for _ in range(3)]
  dec_layers += [
      Dense(28 * 28 * 1, activation='linear'),
      Reshape([28, 28, 1]),
  ]
  nets['decoder'] = SequentialNetwork(dec_layers, name='Decoder')
  return nets
def get_model(
    args: Arguments,
    return_dataset: bool = True,
    encoder: Any = None,
    decoder: Any = None,
    latents: Any = None,
    observation: Any = None,
    **kwargs,
) -> Union[VariationalModel, Tuple[VariationalModel, ImageDataset]]:
  """Create and build the VAE requested by `args.vae` on dataset `args.ds`.

  Args:
    args: experiment arguments; `args.vae` selects the model class,
      `args.ds` the dataset; `args.zdim < 1` means "use the default zdim".
    return_dataset: when True, also return the `ImageDataset`.
    encoder, decoder, latents, observation: optional overrides for the
      corresponding default networks from `get_networks`.
    **kwargs: forwarded to the VAE constructor.

  Returns:
    The built model, or a `(model, dataset)` tuple when
    `return_dataset=True`.
  """
  ds = _DS[args.ds]
  vae_cls = get_vae(args.vae)
  networks = get_networks(ds.name,
                          is_semi_supervised=vae_cls.is_semi_supervised(),
                          is_hierarchical=vae_cls.is_hierarchical(),
                          zdim=None if args.zdim < 1 else args.zdim)
  # FIX: explicit mapping instead of scanning locals() — clearer, and immune
  # to accidentally picking up unrelated local variables.
  overrides = dict(encoder=encoder,
                   decoder=decoder,
                   latents=latents,
                   observation=observation)
  for key, network in overrides.items():
    if key in networks and network is not None:
      networks[key] = network
  vae = vae_cls(**networks, **kwargs)
  vae.build(ds.full_shape)
  return (vae, ds) if return_dataset else vae
def model_gmmvae2(args: Arguments):
  """GMM-VAE with a 50-component mixture posterior."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return GMMVAE(n_components=50, **nets)
def model_iwgamma2(args: Arguments):
  """Importance-weighted Gamma VAE (gamma=2, 10 importance samples)."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return IWGammaVAE(gamma=2.0, n_iw=10, **nets)
def main(vae, ds, args, parser):
  """Train (and optionally evaluate) one VAE variant on one image dataset.

  Args:
    vae: name of the VAE class, resolved via `get_vae`.
    ds: name of the dataset, resolved via `get_dataset`.
    args: parsed command-line arguments (zdim, alpha, ratio, coef, it,
      override, eval, ...).
    parser: the argument parser, used to detect non-default
      semi-supervised options.
  """
  n_labeled = SEMI_SETTING[ds]
  vae = get_vae(vae)
  ds = get_dataset(ds)
  batch_size = BS_SETTING[ds.name]
  assert isinstance(
      ds, ImageDataset), f'Only support image dataset but given {ds}'
  vae: Type[VariationalAutoencoder]
  ds: ImageDataset
  is_semi = vae.is_semi_supervised()
  ## skip unsupervised system, if there are semi-supervised modifications
  if not is_semi:
    for key in ('alpha', 'coef', 'ratio'):
      if parser.get_default(key) != getattr(args, key):
        print('Skip semi-supervised training for:', args)
        return
  ## prepare the arguments
  kw = {}
  ## path: encode the hyper-parameters into the experiment directory name,
  ## e.g. mnist_betavae_z32_a10_r0.1
  name = f'{ds.name}_{vae.__name__.lower()}'
  path = f'{ROOT}/{name}'
  anno = []
  if args.zdim > 0:
    anno.append(f'z{int(args.zdim)}')
  if is_semi:
    anno.append(f'a{args.alpha:g}')
    anno.append(f'r{args.ratio:g}')
    kw['alpha'] = args.alpha
  if issubclass(vae, (SemafoBase, MIVAE)):
    anno.append(f'c{args.coef:g}')
    kw['mi_coef'] = args.coef
  if len(anno) > 0:
    path += f"_{'_'.join(anno)}"
  if args.override and os.path.exists(path):
    shutil.rmtree(path)
    print('Override:', path)
  if not os.path.exists(path):
    os.makedirs(path)
  print(path)
  ## data
  train = ds.create_dataset('train',
                            batch_size=batch_size,
                            label_percent=n_labeled if is_semi else False,
                            oversample_ratio=args.ratio)
  valid = ds.create_dataset('valid',
                            label_percent=1.0,
                            batch_size=batch_size // 2)
  ## create model
  vae = vae(**kw,
            **get_networks(ds.name,
                           zdim=int(args.zdim),
                           is_semi_supervised=is_semi))
  vae.build((None, ) + ds.shape)
  print(vae)
  # resume from an earlier checkpoint when one exists
  vae.load_weights(f'{path}/model', verbose=True)
  best_llk = []

  ## training
  def callback():
    # Validation callback: compute log-likelihood (and accuracy for
    # semi-supervised models), checkpoint on improvement, log to tensorboard.
    llk = []
    y_true = []
    y_pred = []
    for x, y in tqdm(valid.take(500)):
      P, Q = vae(x, training=False)
      P = as_tuple(P)
      llk.append(P[0].log_prob(x))
      if is_semi:
        y_true.append(np.argmax(y, -1))
        # P[1] is presumably the label distribution — TODO confirm
        y_pred.append(np.argmax(get_ymean(P[1]), -1))
    # accuracy
    if is_semi:
      y_true = np.concatenate(y_true, axis=0)
      y_pred = np.concatenate(y_pred, axis=0)
      acc = accuracy_score(y_true=y_true, y_pred=y_pred)
    else:
      acc = 0
    # log-likelihood
    llk = tf.reduce_mean(tf.concat(llk, 0))
    best_llk.append(llk)
    text = f'#{vae.step.numpy()} llk={llk:.2f} acc={acc:.2f}'
    # save only when this is the best validation llk so far
    if llk >= np.max(best_llk):
      vae.save_weights(f'{path}/model')
      vae.trainer.print(f'best llk {text}')
    else:
      vae.trainer.print(f'worse llk {text}')
    # tensorboard summary
    tf.summary.scalar('llk_valid', llk)
    tf.summary.scalar('acc_valid', acc)

  optim_info = get_optimizer_info(ds.name, batch_size=batch_size)
  if args.it > 0:
    optim_info['max_iter'] = int(args.it)
  vae.fit(
      train,
      skip_fitted=True,
      logdir=path,
      on_valid_end=callback,
      clipnorm=100,
      logging_interval=10,
      valid_interval=180,
      nan_gradients_policy='stop',
      **optim_info,
  )
  ## evaluating: reload the best checkpoint before evaluation
  vae.load_weights(f'{path}/model', verbose=True)
  if args.eval:
    evaluate(vae, ds, path, f'{name}')
def model_equilibriumvae8(args: Arguments):
  """EquilibriumVAE variant with random capacity (R=0, C=1)."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return EquilibriumVAE(**nets,
                        R=0.0,
                        C=1.0,
                        random_capacity=True,
                        dropout=0.0,
                        beta=1.)
def model_equilibriumvae9(args: Arguments):
  """EquilibriumVAE variant with negative rate target (R=-0.5, C=0)."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return EquilibriumVAE(**nets, R=-0.5, C=0., dropout=0., beta=1.)
def model_equilibriumvae3(args: Arguments):
  """EquilibriumVAE variant with capacity C=1.5."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return EquilibriumVAE(**nets, C=1.5)
def model_rvae(args: Arguments):
  """Plain VAE run in reverse mode."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return VariationalAutoencoder(**nets, reverse=True)
def model_gvae4(args: Arguments):
  """Beta-Gamma VAE with beta=2 and gamma=5."""
  networks = get_networks(args.ds,
                          zdim=args.zdim,
                          is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaGammaVAE(beta=2.0, gamma=5.0, **networks)
def model_bvae3(args: Arguments):
  """Beta-VAE with beta=0.5 (weaker-than-standard KL weight)."""
  networks = get_networks(args.ds,
                          zdim=args.zdim,
                          is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaVAE(beta=0.5, **networks)
def model_bcvae2(args: Arguments):
  """Beta-Capacity VAE annealing the capacity from 0.01 to 25 over 60k steps."""
  networks = get_networks(args.ds,
                          zdim=args.zdim,
                          is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaCapacityVAE(c_min=0.01,
                         c_max=25,
                         gamma=5,
                         n_steps=60000,
                         **networks)
from odin.bay import TwoStageVAE, plot_latent_stats
from odin.fuel import MNIST
from odin.networks import get_networks
from odin import visual as vs
from tqdm import tqdm
from odin.ml import fast_tsne

# Script: load (or train) a TwoStageVAE on MNIST, then encode the
# validation set through both stages for latent-space analysis.
ds = MNIST()
train = ds.create_dataset('train', batch_size=32)
valid = ds.create_dataset('valid',
                          batch_size=36,
                          label_percent=1.0,
                          drop_remainder=True)
vae = TwoStageVAE(**get_networks(ds.name))
vae.build(ds.full_shape)
# Hard-coded switch: True loads a pre-trained checkpoint; flip to False to
# train from scratch, save, and exit.
if True:
  vae.load_weights('/tmp/twostagevae', verbose=True, raise_notfound=True)
else:
  vae.fit(train, learning_rate=1e-3, max_iter=300000)
  vae.save_weights('/tmp/twostagevae')
  exit()
# Collect latent statistics over the validation set.
# NOTE(review): U, Z_hat and Y are declared but never appended in this
# chunk — the loop body presumably continues beyond this excerpt; confirm.
Z = []
U = []
Z_hat = []
Y = []
for x, y in tqdm(valid):
  qz_x, qu_z, qz_u = vae.encode_two_stages(x)
  Z.append(qz_x.mean())
def run_task(args, evaluation=False):
  """Train or evaluate one SemafoVAE configuration.

  Args:
    args: dict-like task spec with keys 'model' (SemafoVAE subclass),
      'ds' (dataset name), 'py' (labeled fraction), 'coef'
      (mutual-information coefficient Interpolation).
    evaluation: when True, skip training and only run evaluation on a
      previously trained model.
  """
  from odin.fuel import get_dataset
  from odin.networks import get_networks, get_optimizer_info
  from odin.bay.vi import DisentanglementGym, Correlation, DimReduce
  ######## arguments
  model: SemafoVAE = args['model']
  dsname: str = args['ds']
  py: float = args['py']
  coef: Interpolation = args['coef']
  model_name = model.__name__.lower()
  ######## prepare path
  logdir = f'{outdir}/{dsname}_{py}/{model_name}_{coef.name}'
  if OVERRIDE and os.path.exists(logdir) and not evaluation:
    shutil.rmtree(logdir)
    print(f'Override model at path: {logdir}')
  if not os.path.exists(logdir):
    os.makedirs(logdir)
  modelpath = f'{logdir}/model'
  ######## dataset
  ds = get_dataset(dsname)
  # celeba is normalized to [-1, 1] (tanh); other datasets use probabilities
  tanh_norm = True if dsname == 'celeba' else False
  train = ds.create_dataset('train',
                            batch_size=32,
                            label_percent=py,
                            normalize='tanh' if tanh_norm else 'probs')
  valid = ds.create_dataset(
      'valid',
      batch_size=32,
      label_percent=True,
      normalize='tanh' if tanh_norm else 'probs').shuffle(
          1000, seed=1, reshuffle_each_iteration=True)
  ######## model
  networks = get_networks(dsname,
                          centerize_image=False if tanh_norm else True,
                          is_semi_supervised=True)
  # supervised weight scales inversely with the labeled fraction
  vae = model(alpha=1. / py, mi_coef=coef, **networks)
  vae.build((None, ) + ds.shape)
  vae.load_weights(modelpath, verbose=True)
  ######## evaluation
  if evaluation:
    # skip evaluation of an untrained model
    if vae.step.numpy() <= 1:
      return
    evaluate(vae,
             ds,
             expdir=f'{logdir}/analysis',
             title=f'{dsname}{py}_{model_name}_{coef.name}')
    return
  ######## training
  best_llk_x = []
  best_llk_y = []

  def callback():
    # Validation callback: track image and label log-likelihoods;
    # checkpoint when the image llk improves.
    llk_x = []
    llk_y = []
    for x, y in valid.take(300):
      px, (qz, qy) = vae(x, training=False)
      llk_x.append(px.log_prob(x))
      llk_y.append(qy.log_prob(y))
    llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0))
    llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0))
    best_llk_x.append(llk_x)
    best_llk_y.append(llk_y)
    if llk_x >= np.max(best_llk_x):
      vae.save_weights(modelpath)
      vae.trainer.print(f'{model_name} {dsname} {py} '
                        f'best weights at iter#{vae.step.numpy()} '
                        f'llk_x={llk_x:.2f} llk_y={llk_y:.2f}')

  opt_info = get_optimizer_info(dsname)
  # extend the default training budget
  opt_info['max_iter'] += 20000
  vae.fit(
      train,
      logdir=logdir,
      on_valid_end=callback,
      valid_interval=60,
      logging_interval=2,
      nan_gradients_policy='stop',
      compile_graph=True,
      skip_fitted=True,
      **opt_info,
  )
  print(f'Trained {model_name} {dsname} {py} {vae.step.numpy()}(steps)')
def model_vae3(args: Arguments):
  """Plain VAE with a 1.5-nat free-bits floor on the KL."""
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return VariationalAutoencoder(**nets, free_bits=1.5)
def model_gmmvae3(args: Arguments):
  """GMM-VAE (10 components) with an explicit N(0, I) prior and
  non-analytic KL estimation."""
  std_normal = Independent(
      Normal(loc=tf.zeros([args.zdim]), scale=tf.ones([args.zdim])), 1)
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return GMMVAE(n_components=10, prior=std_normal, analytic=False, **nets)
def main(cfg: dict):
  """Entry point: build, train and/or evaluate the VAE described by `cfg`.

  Args:
    cfg: configuration object with at least `vae`, `ds`, `skip` and `eval`.
      NOTE(review): annotated `dict` but accessed via both attributes
      (`cfg.vae`) and `.items()` — presumably an OmegaConf DictConfig;
      confirm.
  """
  assert cfg.vae is not None, \
    ('No VAE model given, select one of the following: '
     f"{', '.join(i.__name__.lower() for i in get_vae())}")
  assert cfg.ds is not None, \
    ('No dataset given, select one of the following: '
     'mnist, dsprites, shapes3d, celeba, cortex, newsgroup20, newsgroup5, ...')
  ### load dataset
  ds = get_dataset(name=cfg.ds)
  ds_kw = dict(batch_size=batch_size, drop_remainder=True)
  ### path, save the output to the subfolder with dataset name
  output_dir = get_output_dir(subfolder=cfg.ds.lower())
  gym_train_path = os.path.join(output_dir, 'gym_train')
  gym_valid_path = os.path.join(output_dir, 'gym_valid')
  gym_test_path = os.path.join(output_dir, 'gym_test')
  model_path = os.path.join(output_dir, 'model')
  ### prepare model init: forward only the cfg entries that match the
  ### constructor's parameter names
  model = get_vae(cfg.vae)
  model_kw = inspect.getfullargspec(model.__init__).args[1:]
  model_kw = {k: v for k, v in cfg.items() if k in model_kw}
  is_semi_supervised = ds.has_labels and model.is_semi_supervised()
  if is_semi_supervised:
    train = ds.create_dataset(partition='train', label_percent=0.1, **ds_kw)
    valid = ds.create_dataset(partition='valid', label_percent=1.0, **ds_kw)
  else:
    train = ds.create_dataset(partition='train', label_percent=0., **ds_kw)
    valid = ds.create_dataset(partition='valid', label_percent=0., **ds_kw)
  ### create the model
  vae = model(path=model_path,
              **get_networks(cfg.ds,
                             centerize_image=True,
                             is_semi_supervised=is_semi_supervised,
                             skip_generator=cfg.skip),
              **model_kw)
  vae.build((None, ) + ds.shape)
  vae.load_weights(raise_notfound=False, verbose=True)
  # early stopping maximizes the tracked metric (log-likelihood)
  vae.early_stopping.mode = 'max'
  gym = create_gym(dsname=cfg.ds, vae=vae)

  ### fit the network
  def callback():
    # Validation callback: feed the validation llk to early stopping,
    # checkpoint on improvement, and return gym metrics for logging.
    metrics = vae.trainer.last_valid_metrics
    llk = metrics['llk_image'] if 'llk_image' in metrics else metrics[
        'llk_dense_type']
    vae.early_stopping.update(llk)
    signal = vae.early_stopping(verbose=True)
    if signal > 0:
      vae.save_weights(overwrite=True)
    # create the return metrics
    return dict(**gym.train()(prefix='train/', dpi=150),
                **gym.valid()(prefix='valid/', dpi=150))

  ### evaluation
  if cfg.eval:
    vae.load_weights()
    gym.train()
    gym(save_path=gym_train_path, dpi=200, verbose=True)
    gym.valid()
    gym(save_path=gym_valid_path, dpi=200, verbose=True)
    gym.test()
    gym(save_path=gym_test_path, dpi=200, verbose=True)
  ### fit
  else:
    vae.early_stopping.patience = 10
    vae.fit(train,
            valid=valid,
            epochs=-1,
            clipnorm=100,
            valid_interval=30,
            logging_interval=2,
            skip_fitted=True,
            on_valid_end=callback,
            logdir=output_dir,
            compile_graph=True,
            track_gradients=True,
            **get_optimizer_info(cfg.ds))
    vae.early_stopping.plot_losses(
        path=os.path.join(output_dir, 'early_stopping.png'))
    vae.plot_learning_curves(
        os.path.join(output_dir, 'learning_curves.png'))
def model_equilibriumvae2(args: Arguments):
  """EquilibriumVAE variant with capacity C=1.

  FIX: pass is_hierarchical/is_semi_supervised explicitly, consistent with
  the other model_equilibriumvae* factories (this one previously relied on
  the get_networks defaults).
  """
  nets = get_networks(args.ds,
                      zdim=args.zdim,
                      is_hierarchical=False,
                      is_semi_supervised=False)
  return EquilibriumVAE(**nets, C=1.0)