def on_create_model(self, cfg):
  x_rv = bay.RandomVariable(
      event_shape=self.input_shape,
      posterior='bernoulli',
      projection=False,
      name="Image",
  )
  z_rv = bay.RandomVariable(event_shape=cfg.latent_size,
                            posterior='diag',
                            name="Latent")
  if cfg.vae in ('mutualinfovae',):
    latent_size = cfg.latent_size * 2
  else:
    latent_size = cfg.latent_size
  # create the network
  encoder, decoder = autoencoder.create_image_autoencoder(
      image_shape=self.input_shape,
      latent_shape=latent_size,
      distribution=x_rv.posterior)
  # create the model and criticizer
  self.model = autoencoder.get_vae(cfg.vae)(outputs=x_rv,
                                            latents=z_rv,
                                            encoder=encoder,
                                            decoder=decoder)
  self.z_samples = self.model.sample_prior(16, seed=1)
  self.criticizer = Criticizer(self.model, random_state=1)
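
# Usage sketch (illustrative, not from the source): the hook above only reads
# `cfg.vae` and `cfg.latent_size`, so a minimal config can be a plain
# namespace. The values and the `experiment` object below are hypothetical.
from argparse import Namespace

cfg = Namespace(vae='betavae', latent_size=10)
# experiment.on_create_model(cfg)  # `experiment` must provide `input_shape`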
def on_create_model(self, cfg, model_dir, md5):
  kwargs = dict(
      labels=bay.RandomVariable(self.labels_shape,
                                self.labels_dist,
                                True,
                                name="Labels"),
      factors=bay.RandomVariable(cfg.zdim, 'diag', True, name="Factors"),
  )
  kwargs.update(cfg.kwargs)
  # create the model, keeping only the keyword arguments accepted by the
  # constructor of the selected VAE class
  model = autoencoder.get_vae(cfg.vae)
  kwargs = {
      k: v
      for k, v in kwargs.items()
      if k in inspect.getfullargspec(model.__init__).args
  }
  model = model(encoder=cfg.ds,
                outputs=bay.RandomVariable(self.images_shape,
                                           'bernoulli',
                                           False,
                                           name="Images"),
                latents=bay.RandomVariable(cfg.zdim, 'diag', True,
                                           name="Latents"),
                **kwargs)
  self.model = model
  self.model.load_weights(model_dir)
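
# Sketch of the kwargs-filtering trick used above, shown in isolation: only
# keyword arguments accepted by a class's __init__ survive the filter, so
# a mixed config dict can be passed safely to any constructor. The `Demo`
# class is hypothetical, added purely for illustration.
import inspect

class Demo:

  def __init__(self, alpha=1.0, beta=2.0):
    self.alpha, self.beta = alpha, beta

candidates = dict(alpha=0.5, gamma=3.0)  # `gamma` is not a Demo argument
accepted = {
    k: v
    for k, v in candidates.items()
    if k in inspect.getfullargspec(Demo.__init__).args
}
assert accepted == {'alpha': 0.5}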
def test_all_models(self):
  all_vae = autoencoder.get_vae()
  for vae_cls in all_vae:
    for latents in [
        (autoencoder.RandomVariable(10, name="Latent1"),
         autoencoder.RandomVariable(10, name="Latent2")),
        autoencoder.RandomVariable(10, name="Latent1"),
    ]:
      for sample_shape in [(), 2, (4, 2, 3)]:
        vae_name = vae_cls.__name__
        print(vae_name, sample_shape)
        if isinstance(vae_cls, autoencoder.VariationalAutoencoder):
          vae = vae_cls
        else:
          vae = vae_cls(latents=latents)
        params = vae.trainable_variables
        # the discriminator (if any) is trained separately, so exclude
        # its parameters from the ELBO gradient check
        if hasattr(vae, 'discriminator'):
          disc_params = set(
              id(v) for v in vae.discriminator.trainable_variables)
          params = [i for i in params if id(i) not in disc_params]
        s = vae.sample_prior()
        px = vae.decode(s)
        x = vae.sample_data(5)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
          tape.watch(params)
          px, qz = vae(x, sample_shape=sample_shape)
          elbo, llk, div = vae.elbo(x, px, qz)
        grads = tape.gradient(elbo, params)
        for p, g in zip(params, grads):
          assert g is not None, \
              "Gradient is None, param:%s shape:%s" % (p.name, p.shape)
          g = g.numpy()
          assert np.all(np.logical_not(np.isnan(g))), \
              "NaN gradient param:%s shape:%s" % (p.name, p.shape)
          assert np.all(np.isfinite(g)), \
              "Infinite gradient param:%s shape:%s" % (p.name, p.shape)
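
# Minimal sketch of the gradient sanity-check pattern from the test above,
# demonstrated on a plain Keras layer instead of a VAE (the model here is
# illustrative only): watch an explicit variable list with
# `watch_accessed_variables=False`, then assert every gradient is finite.
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(4)
x = tf.random.normal((8, 3))
layer.build(x.shape)
params = layer.trainable_variables
with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(params)
  loss = tf.reduce_mean(tf.square(layer(x)))
grads = tape.gradient(loss, params)
for p, g in zip(params, grads):
  assert g is not None and np.all(np.isfinite(g.numpy())), p.name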
def main(vae, ds, args, parser):
  n_labeled = SEMI_SETTING[ds]
  vae = get_vae(vae)
  ds = get_dataset(ds)
  batch_size = BS_SETTING[ds.name]
  assert isinstance(ds, ImageDataset), \
      f'Only image datasets are supported, given {ds}'
  vae: Type[VariationalAutoencoder]
  ds: ImageDataset
  is_semi = vae.is_semi_supervised()
  ## skip unsupervised systems if semi-supervised arguments were modified
  if not is_semi:
    for key in ('alpha', 'coef', 'ratio'):
      if parser.get_default(key) != getattr(args, key):
        print('Skip semi-supervised training for:', args)
        return
  ## prepare the arguments
  kw = {}
  ## output path, annotated with the hyperparameters
  name = f'{ds.name}_{vae.__name__.lower()}'
  path = f'{ROOT}/{name}'
  anno = []
  if args.zdim > 0:
    anno.append(f'z{int(args.zdim)}')
  if is_semi:
    anno.append(f'a{args.alpha:g}')
    anno.append(f'r{args.ratio:g}')
    kw['alpha'] = args.alpha
  if issubclass(vae, (SemafoBase, MIVAE)):
    anno.append(f'c{args.coef:g}')
    kw['mi_coef'] = args.coef
  if len(anno) > 0:
    path += f"_{'_'.join(anno)}"
  if args.override and os.path.exists(path):
    shutil.rmtree(path)
    print('Override:', path)
  if not os.path.exists(path):
    os.makedirs(path)
  print(path)
  ## data
  train = ds.create_dataset('train',
                            batch_size=batch_size,
                            label_percent=n_labeled if is_semi else False,
                            oversample_ratio=args.ratio)
  valid = ds.create_dataset('valid',
                            label_percent=1.0,
                            batch_size=batch_size // 2)
  ## create model
  vae = vae(**kw,
            **get_networks(ds.name,
                           zdim=int(args.zdim),
                           is_semi_supervised=is_semi))
  vae.build((None,) + ds.shape)
  print(vae)
  vae.load_weights(f'{path}/model', verbose=True)
  best_llk = []

  ## training
  def callback():
    llk = []
    y_true = []
    y_pred = []
    for x, y in tqdm(valid.take(500)):
      P, Q = vae(x, training=False)
      P = as_tuple(P)
      llk.append(P[0].log_prob(x))
      if is_semi:
        y_true.append(np.argmax(y, -1))
        y_pred.append(np.argmax(get_ymean(P[1]), -1))
    # accuracy
    if is_semi:
      y_true = np.concatenate(y_true, axis=0)
      y_pred = np.concatenate(y_pred, axis=0)
      acc = accuracy_score(y_true=y_true, y_pred=y_pred)
    else:
      acc = 0
    # log-likelihood
    llk = tf.reduce_mean(tf.concat(llk, 0))
    best_llk.append(llk)
    text = f'#{vae.step.numpy()} llk={llk:.2f} acc={acc:.2f}'
    # checkpoint only when the validation log-likelihood improves
    if llk >= np.max(best_llk):
      vae.save_weights(f'{path}/model')
      vae.trainer.print(f'best llk {text}')
    else:
      vae.trainer.print(f'worse llk {text}')
    # tensorboard summary
    tf.summary.scalar('llk_valid', llk)
    tf.summary.scalar('acc_valid', acc)

  optim_info = get_optimizer_info(ds.name, batch_size=batch_size)
  if args.it > 0:
    optim_info['max_iter'] = int(args.it)
  vae.fit(
      train,
      skip_fitted=True,
      logdir=path,
      on_valid_end=callback,
      clipnorm=100,
      logging_interval=10,
      valid_interval=180,
      nan_gradients_policy='stop',
      **optim_info,
  )
  ## evaluating
  vae.load_weights(f'{path}/model', verbose=True)
  if args.eval:
    evaluate(vae, ds, path, name)
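
# Hedged sketch (assumption, not from the source): an argument parser
# consistent with the fields `main` reads from `args` (`zdim`, `alpha`,
# `ratio`, `coef`, `it`, `override`, `eval`). Defaults are illustrative;
# note `main` compares `parser.get_default(key)` against `args` to detect
# modified semi-supervised settings, so defaults matter for the skip logic.
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--zdim', type=int, default=32)
parser.add_argument('--alpha', type=float, default=10.0)
parser.add_argument('--ratio', type=float, default=0.1)
parser.add_argument('--coef', type=float, default=1.0)
parser.add_argument('--it', type=int, default=0)
parser.add_argument('--override', action='store_true')
parser.add_argument('--eval', action='store_true')
# args = parser.parse_args()
# main('betavae', 'mnist', args, parser)  # model/dataset names hypothetical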