Ejemplo n.º 1
0
 def __init__(self, args: Arguments, free_bits=None, beta=1, **kwargs):
   """Build networks for `args.ds` and swap in a diagonal-Normal latent layer."""
   networks = get_networks(args.ds, zdim=args.zdim,
                           is_hierarchical=False, is_semi_supervised=False)
   n_latents = args.zdim
   # Standard Normal prior over all latent dimensions.
   std_normal = Normal(loc=tf.zeros([n_latents]), scale=tf.ones([n_latents]))
   networks['latents'] = DistributionDense(
     units=n_latents * 2,  # two parameter sets for the Normal posterior
     posterior=make_normal,
     prior=std_normal,
     name=networks['latents'].name)
   super().__init__(free_bits=free_bits, beta=beta, **networks, **kwargs)
Ejemplo n.º 2
0
 def __init__(self, args: Arguments, **kwargs):
   """Build networks for `args.ds` with a Gamma-distributed latent layer."""
   networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                           is_semi_supervised=False)
   dim = args.zdim
   # Gamma(concentration=0.3, rate=0.3) prior, factorized per dimension.
   gamma_prior = Gamma(rate=tf.fill([dim], 0.3),
                       concentration=tf.fill([dim], 0.3))
   networks['latents'] = DistributionDense(
     units=dim * 2,  # two parameter sets for the Gamma posterior
     posterior=make_gamma,
     prior=gamma_prior,
     name=networks['latents'].name)
   super().__init__(**networks, **kwargs)
Ejemplo n.º 3
0
def model_fullcov(args: Arguments):
  """VAE whose latent posterior is a full-covariance (MVN-TriL) Gaussian."""
  nets = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                      is_semi_supervised=False)
  # Collapse the latent event shape into a single flat dimension.
  n_latents = int(np.prod(nets['latents'].event_shape))
  standard_normal = Independent(
    Normal(tf.zeros([n_latents]), tf.ones([n_latents])), 1)
  nets['latents'] = RVconf(event_shape=n_latents,
                           projection=True,
                           posterior='mvntril',
                           prior=standard_normal,
                           name='latents').create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')
Ejemplo n.º 4
0
 def __init__(self, args: Arguments, free_bits=None, beta=1., **kwargs):
   """Replace the default observation layer with a Gaussian output head."""
   networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                           is_semi_supervised=False)
   obs: DistributionDense = networks['observation']
   shape = obs.event_shape
   # Twice the flattened event size: one half per Gaussian parameter set.
   networks['observation'] = DistributionDense(
     event_shape=shape,
     units=int(np.prod(shape)) * 2,
     posterior=partial(make_gaussian_out, event_shape=shape),
     name='image')
   super().__init__(free_bits=free_bits, beta=beta, **networks, **kwargs)
Ejemplo n.º 5
0
 def __init__(self, args: Arguments, **kwargs):
   """Use two parallel diagonal-Normal latent layers instead of one."""
   networks = get_networks(args.ds, zdim=args.zdim,
                           is_hierarchical=False,
                           is_semi_supervised=False)
   dim = args.zdim
   prior = Normal(loc=tf.zeros([dim]), scale=tf.ones([dim]))
   # Two independent latent branches sharing the same standard-Normal prior.
   networks['latents'] = [
     DistributionDense(units=dim * 2, posterior=make_normal, prior=prior,
                       name=f'latents{i}')
     for i in (1, 2)
   ]
   super().__init__(**networks, **kwargs)
Ejemplo n.º 6
0
def create_model(args: Namespace) -> VariationalModel:
    """Look up a `VariationalModel` subclass by name and instantiate it.

    Scans module globals for a class whose name matches `args.vae`
    (case-insensitive), filters the prepared networks down to the keyword
    parameters its `__init__` accepts, builds the model on IMAGE_SHAPE
    inputs and returns it. Raises ValueError when no class matches.
    """
    networks = get_networks(args.ds, zdim=args.zdim, is_semi_supervised=True)
    name = args.vae
    for cls_name, cls in globals().items():
        # Only consider VariationalModel subclasses with a matching name.
        if not (isinstance(cls, type) and issubclass(cls, VariationalModel)):
            continue
        if name.lower() != cls_name.lower():
            continue
        spec = inspect.getfullargspec(cls.__init__)
        accepted = set(spec.args) | set(spec.kwonlyargs)
        kwargs = {key: net for key, net in networks.items() if key in accepted}
        model = cls(**kwargs)
        model.build([None] + IMAGE_SHAPE)
        return model
    raise ValueError(f'Cannot find model with name: {name}')
Ejemplo n.º 7
0
def model_gmmprior(args: Arguments):
  """VAE with a trainable GMM prior of 100 diagonal-Gaussian components.

  Keeps the default diagonal posterior network and only replaces its
  `prior` with a learnable mixture of Gaussians.
  """
  nets = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                      is_semi_supervised=False)
  # Cast to a plain Python int (np.prod returns np.int64) — consistent with
  # model_fullcov / model_fullcovgmm, and keeps TF variable shapes clean.
  latent_size = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  loc = tf.compat.v1.get_variable(name="loc", shape=[n_components, latent_size])
  raw_scale_diag = tf.compat.v1.get_variable(
    name="raw_scale_diag", shape=[n_components, latent_size])
  mixture_logits = tf.compat.v1.get_variable(
    name="mixture_logits", shape=[n_components])
  # softplus keeps scales positive; exp(-7) is a small floor to avoid
  # degenerate (near-zero) component scales.
  nets['latents'].prior = MixtureSameFamily(
    components_distribution=MultivariateNormalDiag(
      loc=loc,
      scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
    mixture_distribution=Categorical(logits=mixture_logits),
    name="prior")
  return VariationalAutoencoder(**nets, name='GMMPrior')
Ejemplo n.º 8
0
def model_fullcovgmm(args: Arguments):
  """Full-covariance (MVN-TriL) posterior combined with a trainable GMM prior."""
  nets = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                      is_semi_supervised=False)
  zdims = int(np.prod(nets['latents'].event_shape))
  n_components = 100
  # Trainable mixture parameters: component means, raw scales and weights.
  loc = tf.compat.v1.get_variable(name="loc", shape=[n_components, zdims])
  raw_scale_diag = tf.compat.v1.get_variable(
    name="raw_scale_diag", shape=[n_components, zdims])
  mixture_logits = tf.compat.v1.get_variable(
    name="mixture_logits", shape=[n_components])
  gmm_prior = MixtureSameFamily(
    components_distribution=MultivariateNormalDiag(
      loc=loc,
      scale_diag=tf.nn.softplus(raw_scale_diag) + tf.math.exp(-7.)),
    mixture_distribution=Categorical(logits=mixture_logits),
    name="prior")
  nets['latents'] = RVconf(event_shape=zdims,
                           projection=True,
                           posterior='mvntril',
                           prior=gmm_prior,
                           name='latents').create_posterior()
  return VariationalAutoencoder(**nets, name='FullCov')
def get_dense_networks(args: Arguments):
    """Replace the default MNIST encoder/decoder with 3x1024-unit ReLU MLPs."""
    networks = get_networks('mnist',
                            is_semi_supervised=False,
                            is_hierarchical=False,
                            zdim=args.zdim)

    def _hidden_stack():
        # Fresh layer instances per call — Keras layers cannot be shared
        # between two models.
        return [Dense(1024, activation='relu') for _ in range(3)]

    networks['encoder'] = SequentialNetwork(
        [InputLayer(input_shape=[28, 28, 1]), CenterAt0(), Flatten()]
        + _hidden_stack(),
        name='Encoder')
    networks['decoder'] = SequentialNetwork(
        [InputLayer(input_shape=[args.zdim])]
        + _hidden_stack()
        + [Dense(28 * 28 * 1, activation='linear'), Reshape([28, 28, 1])],
        name='Decoder')
    return networks
Ejemplo n.º 10
0
def get_model(
    args: Arguments,
    return_dataset: bool = True,
    encoder: Any = None,
    decoder: Any = None,
    latents: Any = None,
    observation: Any = None,
    **kwargs
) -> Union[VariationalModel, Tuple[VariationalModel, ImageDataset]]:
    """Create and build the VAE named by `args.vae` for dataset `args.ds`.

    Any of `encoder`, `decoder`, `latents`, `observation` that is not None
    replaces the corresponding default network from `get_networks`.

    Returns the built model, or `(model, dataset)` when `return_dataset`
    is True.
    """
    ds = _DS[args.ds]
    vae_name = args.vae
    vae_cls = get_vae(vae_name)
    networks = get_networks(ds.name,
                            is_semi_supervised=vae_cls.is_semi_supervised(),
                            is_hierarchical=vae_cls.is_hierarchical(),
                            zdim=None if args.zdim < 1 else args.zdim)
    # Explicit override map instead of iterating locals(): locals() is
    # implementation-dependent, silently breaks on variable rename, and
    # hides which parameters are actually meant as overrides.
    overrides = dict(encoder=encoder, decoder=decoder,
                     latents=latents, observation=observation)
    for key, network in overrides.items():
        if key in networks and network is not None:
            networks[key] = network
    vae = vae_cls(**networks, **kwargs)
    vae.build(ds.full_shape)
    if return_dataset:
        return vae, ds
    return vae
Ejemplo n.º 11
0
def model_gmmvae2(args: Arguments):
  """GMMVAE with 50 mixture components on the default networks."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return GMMVAE(n_components=50, **networks)
Ejemplo n.º 12
0
def model_iwgamma2(args: Arguments):
  """Importance-weighted Gamma VAE: gamma=2.0 with 10 importance samples."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return IWGammaVAE(gamma=2.0, n_iw=10, **networks)
Ejemplo n.º 13
0
def main(vae, ds, args, parser):
    """Train (and optionally evaluate) a VAE on an image dataset.

    Args:
      vae: name of the VAE class, resolved via `get_vae`.
      ds: name of the dataset, resolved via `get_dataset`.
      args: parsed CLI arguments (zdim, alpha, ratio, coef, it, override,
        eval, ...).
      parser: the argparse parser, used to read default values so that
        unsupervised runs with modified semi-supervised flags can be skipped.

    Side effects: creates/overrides a model directory under ROOT, trains
    the model, saves/loads weights, and writes tensorboard summaries.
    """
    n_labeled = SEMI_SETTING[ds]
    vae = get_vae(vae)
    ds = get_dataset(ds)
    batch_size = BS_SETTING[ds.name]
    assert isinstance(
        ds, ImageDataset), f'Only support image dataset but given {ds}'
    vae: Type[VariationalAutoencoder]
    ds: ImageDataset
    is_semi = vae.is_semi_supervised()
    ## skip unsupervised system, if there are semi-supervised modifications
    if not is_semi:
        for key in ('alpha', 'coef', 'ratio'):
            if parser.get_default(key) != getattr(args, key):
                print('Skip semi-supervised training for:', args)
                return
    ## prepare the arguments
    kw = {}
    ## path
    # Encode run configuration into the directory name, e.g.
    # "<ds>_<vae>_z32_a10_r0.1_c0.5".
    name = f'{ds.name}_{vae.__name__.lower()}'
    path = f'{ROOT}/{name}'
    anno = []
    if args.zdim > 0:
        anno.append(f'z{int(args.zdim)}')
    if is_semi:
        anno.append(f'a{args.alpha:g}')
        anno.append(f'r{args.ratio:g}')
        kw['alpha'] = args.alpha
    if issubclass(vae, (SemafoBase, MIVAE)):
        anno.append(f'c{args.coef:g}')
        kw['mi_coef'] = args.coef
    if len(anno) > 0:
        path += f"_{'_'.join(anno)}"
    if args.override and os.path.exists(path):
        shutil.rmtree(path)
        print('Override:', path)
    if not os.path.exists(path):
        os.makedirs(path)
    print(path)
    ## data
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=n_labeled if is_semi else False,
                              oversample_ratio=args.ratio)
    valid = ds.create_dataset('valid',
                              label_percent=1.0,
                              batch_size=batch_size // 2)
    ## create model
    vae = vae(
        **kw,
        **get_networks(ds.name,
                       zdim=int(args.zdim),
                       is_semi_supervised=is_semi))
    vae.build((None, ) + ds.shape)
    print(vae)
    # Resume from a previous run if weights exist at this path.
    vae.load_weights(f'{path}/model', verbose=True)
    # History of validation log-likelihoods; used for best-checkpoint saving.
    best_llk = []

    ## training
    def callback():
        # Validation pass: collect reconstruction log-likelihood and, for
        # semi-supervised models, label predictions.
        llk = []
        y_true = []
        y_pred = []
        for x, y in tqdm(valid.take(500)):
            P, Q = vae(x, training=False)
            P = as_tuple(P)
            llk.append(P[0].log_prob(x))
            if is_semi:
                y_true.append(np.argmax(y, -1))
                y_pred.append(np.argmax(get_ymean(P[1]), -1))
        # accuracy
        if is_semi:
            y_true = np.concatenate(y_true, axis=0)
            y_pred = np.concatenate(y_pred, axis=0)
            acc = accuracy_score(y_true=y_true, y_pred=y_pred)
        else:
            acc = 0
        # log-likelihood
        llk = tf.reduce_mean(tf.concat(llk, 0))
        best_llk.append(llk)
        text = f'#{vae.step.numpy()} llk={llk:.2f} acc={acc:.2f}'
        # Save weights only when the current llk matches or beats the best
        # seen so far (current value is included in best_llk).
        if llk >= np.max(best_llk):
            vae.save_weights(f'{path}/model')
            vae.trainer.print(f'best llk {text}')
        else:
            vae.trainer.print(f'worse llk {text}')
        # tensorboard summary
        tf.summary.scalar('llk_valid', llk)
        tf.summary.scalar('acc_valid', acc)

    optim_info = get_optimizer_info(ds.name, batch_size=batch_size)
    if args.it > 0:
        optim_info['max_iter'] = int(args.it)
    vae.fit(
        train,
        skip_fitted=True,
        logdir=path,
        on_valid_end=callback,
        clipnorm=100,
        logging_interval=10,
        valid_interval=180,
        nan_gradients_policy='stop',
        **optim_info,
    )
    ## evaluating
    # Reload the best checkpoint saved by the validation callback.
    vae.load_weights(f'{path}/model', verbose=True)
    if args.eval:
        evaluate(vae, ds, path, f'{name}')
Ejemplo n.º 14
0
def model_equilibriumvae8(args: Arguments):
  """EquilibriumVAE: R=0, random capacity C=1, no dropout, beta=1."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return EquilibriumVAE(**networks, R=0.0, C=1.0, random_capacity=True,
                        dropout=0.0, beta=1.)
Ejemplo n.º 15
0
def model_equilibriumvae9(args: Arguments):
  """EquilibriumVAE: R=-0.5, zero capacity, no dropout, beta=1."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return EquilibriumVAE(**networks, R=-0.5, C=0., dropout=0., beta=1.)
Ejemplo n.º 16
0
def model_equilibriumvae3(args: Arguments):
  """EquilibriumVAE with capacity C=1.5, all other settings default."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return EquilibriumVAE(**networks, C=1.5)
Ejemplo n.º 17
0
def model_rvae(args: Arguments):
  """VariationalAutoencoder constructed with `reverse=True`."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return VariationalAutoencoder(**networks, reverse=True)
Ejemplo n.º 18
0
def model_gvae4(args: Arguments):
  """BetaGammaVAE with beta=2.0 and gamma=5.0."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaGammaVAE(**networks, beta=2.0, gamma=5.0)
Ejemplo n.º 19
0
def model_bvae3(args: Arguments):
  """Beta-VAE with beta=0.5 (KL weighted below the standard VAE)."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaVAE(**networks, beta=0.5)
Ejemplo n.º 20
0
def model_bcvae2(args: Arguments):
  """BetaCapacityVAE: capacity annealed 0.01 -> 25 over 60k steps, gamma=5."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return BetaCapacityVAE(**networks, c_min=0.01, c_max=25, gamma=5,
                         n_steps=60000)
Ejemplo n.º 21
0
from odin.bay import TwoStageVAE, plot_latent_stats
from odin.fuel import MNIST
from odin.networks import get_networks
from odin import visual as vs
from tqdm import tqdm
from odin.ml import fast_tsne

# Script: load (or train) a TwoStageVAE on MNIST, then collect latent codes
# from the validation set. NOTE(review): the example appears truncated here —
# U, Z_hat and Y are initialized but never filled in the visible code.
ds = MNIST()
train = ds.create_dataset('train', batch_size=32)
valid = ds.create_dataset('valid',
                          batch_size=36,
                          label_percent=1.0,
                          drop_remainder=True)

vae = TwoStageVAE(**get_networks(ds.name))
vae.build(ds.full_shape)
# Toggle: True loads pre-trained weights; the else-branch trains from scratch,
# saves, and exits.
if True:
    vae.load_weights('/tmp/twostagevae', verbose=True, raise_notfound=True)
else:
    vae.fit(train, learning_rate=1e-3, max_iter=300000)
    vae.save_weights('/tmp/twostagevae')
    exit()

# Accumulators for latent statistics over the validation set.
Z = []
U = []
Z_hat = []
Y = []
for x, y in tqdm(valid):
    # Two-stage encoding: x -> q(z|x) -> q(u|z) -> q(z|u).
    qz_x, qu_z, qz_u = vae.encode_two_stages(x)
    Z.append(qz_x.mean())
Ejemplo n.º 22
0
def run_task(args, evaluation=False):
    """Train or evaluate a SemafoVAE configuration described by `args`.

    Args:
      args: mapping with keys 'model' (SemafoVAE class), 'ds' (dataset name),
        'py' (labeled fraction), 'coef' (mutual-information coefficient
        interpolation).
      evaluation: when True, only run the evaluation branch on a previously
        trained model (skipping models with no training steps).

    Side effects: creates/overrides a log directory, trains the model, and
    saves the best weights according to validation reconstruction llk.
    """
    from odin.fuel import get_dataset
    from odin.networks import get_networks, get_optimizer_info
    from odin.bay.vi import DisentanglementGym, Correlation, DimReduce
    ######## arguments
    model: SemafoVAE = args['model']
    dsname: str = args['ds']
    py: float = args['py']
    coef: Interpolation = args['coef']
    model_name = model.__name__.lower()
    ######## prepare path
    logdir = f'{outdir}/{dsname}_{py}/{model_name}_{coef.name}'
    if OVERRIDE and os.path.exists(logdir) and not evaluation:
        shutil.rmtree(logdir)
        print(f'Override model at path: {logdir}')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    modelpath = f'{logdir}/model'
    ######## dataset
    ds = get_dataset(dsname)
    # celeba uses tanh-normalized images; everything else uses probabilities.
    tanh_norm = True if dsname == 'celeba' else False
    train = ds.create_dataset('train',
                              batch_size=32,
                              label_percent=py,
                              normalize='tanh' if tanh_norm else 'probs')
    valid = ds.create_dataset(
        'valid',
        batch_size=32,
        label_percent=True,
        normalize='tanh' if tanh_norm else 'probs').shuffle(
            1000, seed=1, reshuffle_each_iteration=True)
    ######## model
    networks = get_networks(dsname,
                            centerize_image=False if tanh_norm else True,
                            is_semi_supervised=True)
    # alpha scales inversely with the labeled fraction.
    vae = model(alpha=1. / py, mi_coef=coef, **networks)
    vae.build((None, ) + ds.shape)
    vae.load_weights(modelpath, verbose=True)

    ######## evaluation
    if evaluation:
        # Skip models that were never trained.
        if vae.step.numpy() <= 1:
            return
        evaluate(vae,
                 ds,
                 expdir=f'{logdir}/analysis',
                 title=f'{dsname}{py}_{model_name}_{coef.name}')
        return

    ######## training
    # Histories of validation log-likelihoods (image and label).
    best_llk_x = []
    best_llk_y = []

    def callback():
        # Validation pass: reconstruction llk (x) and label llk (y).
        llk_x = []
        llk_y = []
        for x, y in valid.take(300):
            px, (qz, qy) = vae(x, training=False)
            llk_x.append(px.log_prob(x))
            llk_y.append(qy.log_prob(y))
        llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0))
        llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0))
        best_llk_x.append(llk_x)
        best_llk_y.append(llk_y)
        # Checkpoint only when image llk matches or beats the best so far.
        if llk_x >= np.max(best_llk_x):
            vae.save_weights(modelpath)
            vae.trainer.print(f'{model_name} {dsname} {py} '
                              f'best weights at iter#{vae.step.numpy()} '
                              f'llk_x={llk_x:.2f} llk_y={llk_y:.2f}')

    opt_info = get_optimizer_info(dsname)
    # Extend the default schedule for semi-supervised training.
    opt_info['max_iter'] += 20000
    vae.fit(
        train,
        logdir=logdir,
        on_valid_end=callback,
        valid_interval=60,
        logging_interval=2,
        nan_gradients_policy='stop',
        compile_graph=True,
        skip_fitted=True,
        **opt_info,
    )
    print(f'Trained {model_name} {dsname} {py} {vae.step.numpy()}(steps)')
Ejemplo n.º 23
0
def model_vae3(args: Arguments):
  """Plain VariationalAutoencoder with free_bits=1.5."""
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return VariationalAutoencoder(**networks, free_bits=1.5)
Ejemplo n.º 24
0
def model_gmmvae3(args: Arguments):
  """GMMVAE: 10 components, explicit standard-Normal prior, analytic=False."""
  dim = args.zdim
  std_prior = Independent(Normal(loc=tf.zeros([dim]), scale=tf.ones([dim])), 1)
  networks = get_networks(args.ds, zdim=args.zdim, is_hierarchical=False,
                          is_semi_supervised=False)
  return GMMVAE(n_components=10, prior=std_prior, analytic=False, **networks)
Ejemplo n.º 25
0
def main(cfg: dict):
    """Train or evaluate the VAE named in `cfg.vae` on dataset `cfg.ds`.

    Expects `cfg` to provide at least: vae (model name), ds (dataset name),
    skip (skip_generator flag), eval (evaluation-only flag), plus any extra
    keys matching the model's `__init__` parameters.

    NOTE(review): `batch_size`, `get_output_dir` and `create_gym` come from
    the enclosing module scope — confirm they are defined before calling.
    """
    assert cfg.vae is not None, \
      ('No VAE model given, select one of the following: '
       f"{', '.join(i.__name__.lower() for i in get_vae())}")
    assert cfg.ds is not None, \
      ('No dataset given, select one of the following: '
       'mnist, dsprites, shapes3d, celeba, cortex, newsgroup20, newsgroup5, ...')
    ### load dataset
    ds = get_dataset(name=cfg.ds)
    ds_kw = dict(batch_size=batch_size, drop_remainder=True)
    ### path, save the output to the subfolder with dataset name
    output_dir = get_output_dir(subfolder=cfg.ds.lower())
    gym_train_path = os.path.join(output_dir, 'gym_train')
    gym_valid_path = os.path.join(output_dir, 'gym_valid')
    gym_test_path = os.path.join(output_dir, 'gym_test')
    model_path = os.path.join(output_dir, 'model')
    ### prepare model init
    model = get_vae(cfg.vae)
    # Keep only the cfg entries that match the model's __init__ parameters
    # (skipping 'self' at index 0).
    model_kw = inspect.getfullargspec(model.__init__).args[1:]
    model_kw = {k: v for k, v in cfg.items() if k in model_kw}
    is_semi_supervised = ds.has_labels and model.is_semi_supervised()
    if is_semi_supervised:
        train = ds.create_dataset(partition='train',
                                  label_percent=0.1,
                                  **ds_kw)
        valid = ds.create_dataset(partition='valid',
                                  label_percent=1.0,
                                  **ds_kw)
    else:
        train = ds.create_dataset(partition='train', label_percent=0., **ds_kw)
        valid = ds.create_dataset(partition='valid', label_percent=0., **ds_kw)
    ### create the model
    vae = model(path=model_path,
                **get_networks(cfg.ds,
                               centerize_image=True,
                               is_semi_supervised=is_semi_supervised,
                               skip_generator=cfg.skip),
                **model_kw)
    vae.build((None, ) + ds.shape)
    vae.load_weights(raise_notfound=False, verbose=True)
    # Early stopping tracks a llk metric, so higher is better.
    vae.early_stopping.mode = 'max'
    gym = create_gym(dsname=cfg.ds, vae=vae)

    ### fit the network
    def callback():
        # Pick whichever validation llk metric this model reports.
        metrics = vae.trainer.last_valid_metrics
        llk = metrics['llk_image'] if 'llk_image' in metrics else metrics[
            'llk_dense_type']
        vae.early_stopping.update(llk)
        signal = vae.early_stopping(verbose=True)
        # Positive signal means improvement: checkpoint the weights.
        if signal > 0:
            vae.save_weights(overwrite=True)
        # create the return metrics
        return dict(**gym.train()(prefix='train/', dpi=150),
                    **gym.valid()(prefix='valid/', dpi=150))

    ### evaluation
    if cfg.eval:
        vae.load_weights()
        gym.train()
        gym(save_path=gym_train_path, dpi=200, verbose=True)
        gym.valid()
        gym(save_path=gym_valid_path, dpi=200, verbose=True)
        gym.test()
        gym(save_path=gym_test_path, dpi=200, verbose=True)
    ### fit
    else:
        vae.early_stopping.patience = 10
        vae.fit(train,
                valid=valid,
                epochs=-1,
                clipnorm=100,
                valid_interval=30,
                logging_interval=2,
                skip_fitted=True,
                on_valid_end=callback,
                logdir=output_dir,
                compile_graph=True,
                track_gradients=True,
                **get_optimizer_info(cfg.ds))
        vae.early_stopping.plot_losses(
            path=os.path.join(output_dir, 'early_stopping.png'))
        vae.plot_learning_curves(
            os.path.join(output_dir, 'learning_curves.png'))
Ejemplo n.º 26
0
def model_equilibriumvae2(args: Arguments):
  """EquilibriumVAE with capacity C=1.0 on the default networks."""
  networks = get_networks(args.ds, zdim=args.zdim)
  return EquilibriumVAE(**networks, C=1.0)