# Example #1
# 0
 def on_create_model(self, cfg):
     """Build the VAE, its prior samples and a criticizer for this run."""
     # Observed variable: the image, with a Bernoulli likelihood and no
     # extra projection layer on top of the decoder output.
     image_rv = bay.RandomVariable(
         event_shape=self.input_shape,
         posterior='bernoulli',
         projection=False,
         name="Image",
     )
     # Latent variable with a diagonal-Gaussian posterior.
     latent_rv = bay.RandomVariable(event_shape=cfg.latent_size,
                                    posterior='diag',
                                    name="Latent")
     # MutualInfoVAE needs a latent code twice as wide as configured.
     latent_dim = (cfg.latent_size * 2
                   if cfg.vae in ('mutualinfovae',) else cfg.latent_size)
     # Encoder/decoder networks matching the image shape and likelihood.
     encoder, decoder = autoencoder.create_image_autoencoder(
         image_shape=self.input_shape,
         latent_shape=latent_dim,
         distribution=image_rv.posterior)
     # Assemble the requested VAE variant and its evaluation helpers.
     self.model = autoencoder.get_vae(cfg.vae)(outputs=image_rv,
                                               latents=latent_rv,
                                               encoder=encoder,
                                               decoder=decoder)
     self.z_samples = self.model.sample_prior(16, seed=1)
     self.criticizer = Criticizer(self.model, random_state=1)
# Example #2
# 0
 def on_create_model(self, cfg, model_dir, md5):
     """Instantiate the configured VAE and restore its weights from disk."""
     # Optional constructor arguments; forwarded only when the chosen VAE
     # class actually declares them in its __init__ signature.
     extra = {
         'labels': bay.RandomVariable(self.labels_shape,
                                      self.labels_dist,
                                      True,
                                      name="Labels"),
         'factors': bay.RandomVariable(cfg.zdim, 'diag', True,
                                       name="Factors"),
     }
     extra.update(cfg.kwargs)
     vae_cls = autoencoder.get_vae(cfg.vae)
     accepted = inspect.getfullargspec(vae_cls.__init__).args
     extra = {key: val for key, val in extra.items() if key in accepted}
     # Bernoulli image likelihood, diagonal-Gaussian latent posterior.
     self.model = vae_cls(encoder=cfg.ds,
                          outputs=bay.RandomVariable(self.images_shape,
                                                     'bernoulli',
                                                     False,
                                                     name="Images"),
                          latents=bay.RandomVariable(cfg.zdim,
                                                     'diag',
                                                     True,
                                                     name="Latents"),
                          **extra)
     self.model.load_weights(model_dir)
# Example #3
# 0
 def test_all_models(self):
     """Smoke-test every registered VAE: build it, sample from the prior,
     decode, then verify the ELBO yields a non-None, finite gradient for
     every trainable (non-discriminator) parameter, across single/multiple
     latents and several MCMC sample shapes.

     Fix: removed the `try: ... except Exception as e: raise e` wrapper —
     it was a no-op that only added a traceback frame; exceptions now
     propagate directly.
     """
     all_vae = autoencoder.get_vae()
     for vae_cls in all_vae:
         # exercise both a tuple of latents and a single latent variable
         for latents in [
             (autoencoder.RandomVariable(10, name="Latent1"),
              autoencoder.RandomVariable(10, name="Latent2")),
             autoencoder.RandomVariable(10, name="Latent1"),
         ]:
             # scalar, int, and multi-dimensional sample shapes
             for sample_shape in [(), 2, (4, 2, 3)]:
                 vae_name = vae_cls.__name__
                 print(vae_name, sample_shape)
                 # the registry may hold pre-built instances as well as
                 # classes; only classes need instantiation
                 if isinstance(vae_cls, autoencoder.VariationalAutoencoder):
                     vae = vae_cls
                 else:
                     vae = vae_cls(latents=latents)
                 params = vae.trainable_variables
                 # exclude discriminator weights: they are trained by a
                 # separate objective and receive no gradient from the ELBO
                 if hasattr(vae, 'discriminator'):
                     disc_params = set(
                         id(v) for v in vae.discriminator.trainable_variables)
                     params = [i for i in params if id(i) not in disc_params]
                 s = vae.sample_prior()
                 px = vae.decode(s)
                 x = vae.sample_data(5)
                 with tf.GradientTape(
                         watch_accessed_variables=False) as tape:
                     tape.watch(params)
                     px, qz = vae(x, sample_shape=sample_shape)
                     elbo, llk, div = vae.elbo(x, px, qz)
                 grads = tape.gradient(elbo, params)
                 for p, g in zip(params, grads):
                     assert g is not None, \
                         "Gradient is None, param:%s shape:%s" % (p.name, p.shape)
                     g = g.numpy()
                     assert np.all(np.logical_not(np.isnan(g))), \
                         "NaN gradient param:%s shape:%s" % (p.name, p.shape)
                     assert np.all(np.isfinite(g)), \
                         "Infinite gradient param:%s shape:%s" % (p.name, p.shape)
# Example #4
# 0
def main(vae, ds, args, parser):
    """Train (and optionally evaluate) one VAE variant on one image dataset.

    Args:
      vae: string name of the VAE class, resolved via ``get_vae``.
      ds: string name of the dataset, resolved via ``get_dataset``.
      args: parsed CLI namespace (zdim, alpha, ratio, coef, it, override,
        eval, ...).
      parser: the argparse parser; its defaults are compared against
        ``args`` to detect semi-supervised-only flags on unsupervised runs.
    """
    # NOTE: `ds` is still a string here — SEMI_SETTING is keyed by name,
    # BS_SETTING below is keyed by the dataset object's .name
    n_labeled = SEMI_SETTING[ds]
    vae = get_vae(vae)  # string -> VAE class
    ds = get_dataset(ds)  # string -> dataset object
    batch_size = BS_SETTING[ds.name]
    assert isinstance(
        ds, ImageDataset), f'Only support image dataset but given {ds}'
    vae: Type[VariationalAutoencoder]
    ds: ImageDataset
    is_semi = vae.is_semi_supervised()
    ## skip unsupervised system, if there are semi-supervised modifications
    # (any non-default semi-supervised flag means this combination was not
    # intended to run unsupervised)
    if not is_semi:
        for key in ('alpha', 'coef', 'ratio'):
            if parser.get_default(key) != getattr(args, key):
                print('Skip semi-supervised training for:', args)
                return
    ## prepare the arguments
    kw = {}
    ## path: the run directory name encodes dataset, model and hyperparams
    name = f'{ds.name}_{vae.__name__.lower()}'
    path = f'{ROOT}/{name}'
    anno = []
    if args.zdim > 0:
        anno.append(f'z{int(args.zdim)}')
    if is_semi:
        anno.append(f'a{args.alpha:g}')
        anno.append(f'r{args.ratio:g}')
        kw['alpha'] = args.alpha
    if issubclass(vae, (SemafoBase, MIVAE)):
        anno.append(f'c{args.coef:g}')
        kw['mi_coef'] = args.coef
    if len(anno) > 0:
        path += f"_{'_'.join(anno)}"
    # --override wipes any previous run at the same path
    if args.override and os.path.exists(path):
        shutil.rmtree(path)
        print('Override:', path)
    if not os.path.exists(path):
        os.makedirs(path)
    print(path)
    ## data
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=n_labeled if is_semi else False,
                              oversample_ratio=args.ratio)
    valid = ds.create_dataset('valid',
                              label_percent=1.0,
                              batch_size=batch_size // 2)
    ## create model
    vae = vae(
        **kw,
        **get_networks(ds.name,
                       zdim=int(args.zdim),
                       is_semi_supervised=is_semi))
    vae.build((None, ) + ds.shape)
    print(vae)
    # resume from an earlier checkpoint if one exists at this path
    vae.load_weights(f'{path}/model', verbose=True)
    # history of validation log-likelihoods; the callback closes over it to
    # decide when to checkpoint
    best_llk = []

    ## training
    def callback():
        # Validation pass over up to 500 batches: mean image log-likelihood
        # and, for semi-supervised models, label accuracy.
        llk = []
        y_true = []
        y_pred = []
        for x, y in tqdm(valid.take(500)):
            P, Q = vae(x, training=False)
            P = as_tuple(P)
            # P[0] is the image distribution; P[1] presumably the label
            # distribution for semi-supervised models — see get_ymean
            llk.append(P[0].log_prob(x))
            if is_semi:
                y_true.append(np.argmax(y, -1))
                y_pred.append(np.argmax(get_ymean(P[1]), -1))
        # accuracy
        if is_semi:
            y_true = np.concatenate(y_true, axis=0)
            y_pred = np.concatenate(y_pred, axis=0)
            acc = accuracy_score(y_true=y_true, y_pred=y_pred)
        else:
            acc = 0
        # log-likelihood
        llk = tf.reduce_mean(tf.concat(llk, 0))
        best_llk.append(llk)
        text = f'#{vae.step.numpy()} llk={llk:.2f} acc={acc:.2f}'
        # save a checkpoint only when the current llk ties the best so far
        # (the value just appended is included in the max, so >= holds when
        # it is the new maximum)
        if llk >= np.max(best_llk):
            vae.save_weights(f'{path}/model')
            vae.trainer.print(f'best llk {text}')
        else:
            vae.trainer.print(f'worse llk {text}')
        # tensorboard summary
        tf.summary.scalar('llk_valid', llk)
        tf.summary.scalar('acc_valid', acc)

    optim_info = get_optimizer_info(ds.name, batch_size=batch_size)
    if args.it > 0:
        # CLI override of the training iteration budget
        optim_info['max_iter'] = int(args.it)
    vae.fit(
        train,
        skip_fitted=True,
        logdir=path,
        on_valid_end=callback,
        clipnorm=100,
        logging_interval=10,
        valid_interval=180,
        nan_gradients_policy='stop',
        **optim_info,
    )
    ## evaluating: reload the best checkpoint saved by the callback
    vae.load_weights(f'{path}/model', verbose=True)
    if args.eval:
        evaluate(vae, ds, path, f'{name}')