Example #1
def load_data(name: str, batch_size: int):
    dataset = get_dataset(name)()
    assert dataset.has_labels, f'No labels for given dataset {name}'
    kw = dict(batch_size=batch_size, drop_remainder=True)
    test_l = dataset.create_dataset(partition='test', inc_labels=1.0, **kw)
    sample_images, y = [(x[:16], y[:16]) for x, y in test_l.take(1)][0]
    # inputs structure
    images, labels = tf.data.experimental.get_structure(test_l)
    images_shape = images.shape[1:]
    labels_shape = labels.shape[1:]
    return dataset, (sample_images, y), (images_shape, labels_shape)
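
A minimal usage sketch for load_data; the dataset name 'mnist' and the batch size are placeholders, and get_dataset must resolve the name within the surrounding project:

dataset, (x_sample, y_sample), (x_shape, y_shape) = load_data('mnist', batch_size=32)
print('sample images:', x_sample.shape)  # first 16 labeled test images
print('image shape:', x_shape, '| label shape:', y_shape)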
Example #2
def create_vae(args) -> Tuple[VariationalAutoencoder, ImageDataset]:
    dist = 'bernoulli' if args.ds == 'shapes3d' else 'qlogistic'
    key = f'model_{args.model}'
    if key not in globals():
        raise ValueError(f'Cannot find model with name: {args.model}')
    model = globals()[key](args.zdim, dist)
    model: VariationalAutoencoder
    model.build([None] + IMAGE_SHAPE)
    ds = get_dataset(args.ds)
    ds: ImageDataset
    return model, ds
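
A hedged construction sketch for create_vae; the argument values are placeholders, a factory named model_gaussian must exist in the module's globals(), and IMAGE_SHAPE must already be set at module level:

from argparse import Namespace

args = Namespace(model='gaussian', ds='shapes3d', zdim=32)
model, ds = create_vae(args)  # raises ValueError if model_gaussian is missing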
Example #3
def main(args: Namespace):
    # === 1. get data and set metadata
    ds = get_dataset(args.ds)
    global IMAGE_SHAPE, ZDIM, DS, N_LABELS
    ZDIM = int(args.zdim)
    IMAGE_SHAPE = list(ds.shape)
    DS = ds
    N_LABELS = float(args.y)
    # === 2. create model
    vae = create_model(args)
    if args.eval:
        vae.load_weights(get_model_path())
        evaluate(vae, args)
    else:
        train(vae, ds, args, label_percent=args.y, oversample_ratio=0.5)
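
A sketch of the CLI wiring this main expects; the flag names mirror the attributes read above, the defaults are assumptions, and create_model may read further flags (e.g. a model name):

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--ds', type=str, default='mnist')
parser.add_argument('--zdim', type=int, default=32)
parser.add_argument('--y', type=float, default=0.1)  # fraction of labeled data
parser.add_argument('--eval', action='store_true')
main(parser.parse_args())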
Example #4
def main(args: Arguments):
  # === 0. set configs
  ds = get_dataset(args.ds)
  # === 1. get model
  model = None
  for k, v in globals().items():
    if (inspect.isfunction(v) and k.startswith('model_') and
        k.split('_')[-1] == args.vae):
      model = v(args)
      model.build(ds.full_shape)
      break
  if model is None:
    model = get_model(args, return_dataset=False)
  # === 2. eval
  if args.eval:
    model.load_weights(get_model_path(args), raise_notfound=True, verbose=True)
    evaluate(model, ds, args)
  # === 3. train
  else:
    train(model, ds, args)
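
The globals() loop resolves models by naming convention: a module-level function model_<name> is selected when <name> equals args.vae. A minimal sketch of such a factory (the body is an assumption; only the naming contract matters):

def model_betavae(args: Arguments):
  # picked up by the loop in main when args.vae == 'betavae';
  # the returned object must support .build(shape)
  return get_model(args, return_dataset=False)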
Example #5
def on_load_data(self, cfg):
  ds = get_dataset(cfg.ds)()
  ds.sample_images(save_path=os.path.join(self.save_path, 'samples.png'))
  kw = dict(batch_size=128, drop_remainder=True)
  train = ds.create_dataset(partition='train',
                            inc_labels=float(cfg.semi),
                            **kw)
  train_u = ds.create_dataset(partition='train', inc_labels=False, **kw)
  valid = ds.create_dataset(partition='valid', inc_labels=1.0, **kw)
  valid_u = ds.create_dataset(partition='valid', inc_labels=False, **kw)
  # reduce batch_size here, otherwise mllk takes ~7GB of VRAM
  kw['batch_size'] = 8
  test = ds.create_dataset(partition='test', inc_labels=1.0, **kw)
  test_u = ds.create_dataset(partition='test', inc_labels=False, **kw)
  self.ds = ds
  self.train, self.train_u = train, train_u
  self.valid, self.valid_u = valid, valid_u
  self.test, self.test_u = test, test_u
  if cfg.verbose:
    print("Dataset:", ds)
    print(" train:", train)
    print(" train_u:", train_u)
Example #6
def on_load_data(self, cfg):
    dataset = get_dataset(cfg.ds)()
    train = dataset.create_dataset(partition='train', inc_labels=False)
    valid = dataset.create_dataset(partition='valid', inc_labels=False)
    test = dataset.create_dataset(partition='test', inc_labels=True)
    # sample 16 validation images and 16 test images from the first batch
    x_valid = [x for x in valid.take(1)][0][:16]
    self.x_test = [xy[0] for xy in test.take(1)][0][:16]
    ### input description
    input_spec = tf.data.experimental.get_structure(train)
    fig = plt.figure(figsize=(12, 12))
    for i, x in enumerate(x_valid.numpy()):
        if x.shape[-1] == 1:
            x = np.squeeze(x, axis=-1)
        plt.subplot(4, 4, i + 1)
        plt.imshow(x, cmap='Greys_r' if x.ndim == 2 else None)
        plt.axis('off')
    plt.tight_layout()
    fig.savefig(os.path.join(self.save_path, '%s.pdf' % cfg.ds))
    plt.close(fig)
    ### store
    self.input_dtype = input_spec.dtype
    self.input_shape = input_spec.shape[1:]
    self.train, self.valid, self.test = train, valid, test
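
Here tf.data.experimental.get_structure(train) returns a single TensorSpec because the split is unlabeled, so .dtype and .shape apply directly; the same information is exposed by the now-preferred element_spec attribute. A self-contained check:

import tensorflow as tf

toy = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 28, 28, 1])).batch(4)
spec = toy.element_spec            # equivalent to get_structure(toy)
print(spec.dtype, spec.shape[1:])  # float32, (28, 28, 1)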
Example #7
def on_load_data(self, cfg):
    self.dataset = get_dataset(cfg.ds)()
    kw = dict(batch_size=cfg.batch_size, drop_remainder=True)
    # unlabeled splits
    self.train_u = self.dataset.create_dataset(partition='train',
                                               inc_labels=False,
                                               **kw)
    self.valid_u = self.dataset.create_dataset(partition='valid',
                                               inc_labels=False,
                                               **kw)
    self.test_u = self.dataset.create_dataset(partition='test',
                                              inc_labels=False,
                                              **kw)
    # labeled splits (10% of train labels, all valid/test labels)
    self.train_l = self.dataset.create_dataset(partition='train',
                                               inc_labels=0.1,
                                               **kw)
    self.valid_l = self.dataset.create_dataset(partition='valid',
                                               inc_labels=1.0,
                                               **kw)
    self.test_l = self.dataset.create_dataset(partition='test',
                                              inc_labels=1.0,
                                              **kw)
    # sample the first 16 labeled test images for visualization
    self.sample_images, y = [(x[:16], y[:16])
                             for x, y in self.test_l.take(1)][0]
    # infer the label distribution from one batch of labels
    if np.any(np.sum(y, axis=1) > 1):
        if np.any(y > 1):
            self.labels_dist = "nb"  # negative binomial: count labels
        else:
            self.labels_dist = "bernoulli"  # multi-hot binary labels
    else:
        self.labels_dist = "onehot"
    # inputs structure (batch dimension dropped)
    images, labels = tf.data.experimental.get_structure(
        self.train_l)['inputs']
    self.images_shape = images.shape[1:]
    self.labels_shape = labels.shape[1:]
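
The label-distribution heuristic above uses only numpy, so it can be checked in isolation; a self-contained sketch with toy label matrices:

import numpy as np

def infer_labels_dist(y: np.ndarray) -> str:
    # same rule as above: row sums > 1 mean multi-label,
    # and values > 1 mean counts rather than binary indicators
    if np.any(np.sum(y, axis=1) > 1):
        return 'nb' if np.any(y > 1) else 'bernoulli'
    return 'onehot'

print(infer_labels_dist(np.eye(4)))              # onehot
print(infer_labels_dist(np.array([[1, 1, 0]])))  # bernoulli
print(infer_labels_dist(np.array([[2, 3, 0]])))  # nb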
Example #8
__all__ = [
    'set_cfg',
    'get_output_dir',
    'get_model',
    'get_model_path',
    'get_results_path',
    'get_args',
    'run_multi',
    'train',
]

_root_path: str = '/tmp/model'
_logging_interval: float = 5.
_valid_interval: float = 80.
_n_valid_batches: int = 400
_extra_path = {}
_DS: Dict[str, ImageDataset] = defaultdictkey(lambda name: get_dataset(name))


# ===========================================================================
# Helper for evaluation
# ===========================================================================
@dataclasses.dataclass
class Arguments:
    vae: str = ''
    ds: str = ''
    zdim: Union[str, int] = 32
    it: int = 80000
    bs: int = 32
    clipnorm: float = 100.
    dpi: int = 100
    points: int = 4000
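
Since Arguments is a plain dataclass, instances can be built directly and inspected with the standard dataclasses helpers; the field values below are placeholders:

args = Arguments(vae='betavae', ds='mnist', zdim=64)
print(args.bs, args.it)          # 32 80000 (defaults)
print(dataclasses.asdict(args))  # full configuration as a dict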
Example #9
def run_task(args, evaluation=False):
    from odin.fuel import get_dataset
    from odin.networks import get_networks, get_optimizer_info
    from odin.bay.vi import DisentanglementGym, Correlation, DimReduce
    ######## arguments
    model: Type[SemafoVAE] = args['model']  # a class, instantiated below
    dsname: str = args['ds']
    py: float = args['py']
    coef: Interpolation = args['coef']
    model_name = model.__name__.lower()
    ######## prepare path
    logdir = f'{outdir}/{dsname}_{py}/{model_name}_{coef.name}'
    if OVERRIDE and os.path.exists(logdir) and not evaluation:
        shutil.rmtree(logdir)
        print(f'Override model at path: {logdir}')
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    modelpath = f'{logdir}/model'
    ######## dataset
    ds = get_dataset(dsname)
    tanh_norm = (dsname == 'celeba')
    train = ds.create_dataset('train',
                              batch_size=32,
                              label_percent=py,
                              normalize='tanh' if tanh_norm else 'probs')
    valid = ds.create_dataset(
        'valid',
        batch_size=32,
        label_percent=1.0,
        normalize='tanh' if tanh_norm else 'probs').shuffle(
            1000, seed=1, reshuffle_each_iteration=True)
    ######## model
    networks = get_networks(dsname,
                            centerize_image=not tanh_norm,
                            is_semi_supervised=True)
    vae = model(alpha=1. / py, mi_coef=coef, **networks)
    vae.build((None, ) + ds.shape)
    vae.load_weights(modelpath, verbose=True)

    ######## evaluation
    if evaluation:
        if vae.step.numpy() <= 1:
            return
        evaluate(vae,
                 ds,
                 expdir=f'{logdir}/analysis',
                 title=f'{dsname}{py}_{model_name}_{coef.name}')
        return

    ######## training
    best_llk_x = []
    best_llk_y = []

    def callback():
        llk_x = []
        llk_y = []
        for x, y in valid.take(300):
            px, (qz, qy) = vae(x, training=False)
            llk_x.append(px.log_prob(x))
            llk_y.append(qy.log_prob(y))
        llk_x = tf.reduce_mean(tf.concat(llk_x, axis=0))
        llk_y = tf.reduce_mean(tf.concat(llk_y, axis=0))
        best_llk_x.append(llk_x)
        best_llk_y.append(llk_y)
        if llk_x >= np.max(best_llk_x):
            vae.save_weights(modelpath)
            vae.trainer.print(f'{model_name} {dsname} {py} '
                              f'best weights at iter#{vae.step.numpy()} '
                              f'llk_x={llk_x:.2f} llk_y={llk_y:.2f}')

    opt_info = get_optimizer_info(dsname)
    opt_info['max_iter'] += 20000
    vae.fit(
        train,
        logdir=logdir,
        on_valid_end=callback,
        valid_interval=60,
        logging_interval=2,
        nan_gradients_policy='stop',
        compile_graph=True,
        skip_fitted=True,
        **opt_info,
    )
    print(f'Trained {model_name} {dsname} {py} {vae.step.numpy()}(steps)')
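
The callback above implements a keep-best-weights rule: checkpoint whenever the validation llk_x reaches a new maximum (ties included, so the first call always saves). The same pattern in isolation, with save_fn standing in for vae.save_weights:

import numpy as np

def make_best_tracker(save_fn):
    history = []
    def update(llk_x: float):
        history.append(llk_x)
        if llk_x >= np.max(history):  # new best so far
            save_fn()
    return update

track = make_best_tracker(lambda: print('checkpoint'))
for v in [-120.0, -95.5, -110.2, -90.1]:
    track(v)  # checkpoints at -120.0, -95.5 and -90.1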
Example #10
def main(vae, ds, args, parser):
    n_labeled = SEMI_SETTING[ds]
    vae = get_vae(vae)
    ds = get_dataset(ds)
    batch_size = BS_SETTING[ds.name]
    assert isinstance(
        ds, ImageDataset), f'Only support image dataset but given {ds}'
    vae: Type[VariationalAutoencoder]
    ds: ImageDataset
    is_semi = vae.is_semi_supervised()
    ## skip unsupervised models when semi-supervised-only hyperparameters
    ## were changed (avoids duplicate unsupervised runs in a grid search)
    if not is_semi:
        for key in ('alpha', 'coef', 'ratio'):
            if parser.get_default(key) != getattr(args, key):
                print('Skip semi-supervised training for:', args)
                return
    ## prepare the arguments
    kw = {}
    ## path
    name = f'{ds.name}_{vae.__name__.lower()}'
    path = f'{ROOT}/{name}'
    anno = []
    if args.zdim > 0:
        anno.append(f'z{int(args.zdim)}')
    if is_semi:
        anno.append(f'a{args.alpha:g}')
        anno.append(f'r{args.ratio:g}')
        kw['alpha'] = args.alpha
    if issubclass(vae, (SemafoBase, MIVAE)):
        anno.append(f'c{args.coef:g}')
        kw['mi_coef'] = args.coef
    if len(anno) > 0:
        path += f"_{'_'.join(anno)}"
    if args.override and os.path.exists(path):
        shutil.rmtree(path)
        print('Override:', path)
    if not os.path.exists(path):
        os.makedirs(path)
    print(path)
    ## data
    train = ds.create_dataset('train',
                              batch_size=batch_size,
                              label_percent=n_labeled if is_semi else False,
                              oversample_ratio=args.ratio)
    valid = ds.create_dataset('valid',
                              label_percent=1.0,
                              batch_size=batch_size // 2)
    ## create model
    vae = vae(
        **kw,
        **get_networks(ds.name,
                       zdim=int(args.zdim),
                       is_semi_supervised=is_semi))
    vae.build((None, ) + ds.shape)
    print(vae)
    vae.load_weights(f'{path}/model', verbose=True)
    best_llk = []

    ## training
    def callback():
        llk = []
        y_true = []
        y_pred = []
        for x, y in tqdm(valid.take(500)):
            P, Q = vae(x, training=False)
            P = as_tuple(P)
            llk.append(P[0].log_prob(x))
            if is_semi:
                y_true.append(np.argmax(y, -1))
                y_pred.append(np.argmax(get_ymean(P[1]), -1))
        # accuracy
        if is_semi:
            y_true = np.concatenate(y_true, axis=0)
            y_pred = np.concatenate(y_pred, axis=0)
            acc = accuracy_score(y_true=y_true, y_pred=y_pred)
        else:
            acc = 0
        # log-likelihood
        llk = tf.reduce_mean(tf.concat(llk, 0))
        best_llk.append(llk)
        text = f'#{vae.step.numpy()} llk={llk:.2f} acc={acc:.2f}'
        if llk >= np.max(best_llk):
            vae.save_weights(f'{path}/model')
            vae.trainer.print(f'best llk {text}')
        else:
            vae.trainer.print(f'worse llk {text}')
        # tensorboard summary
        tf.summary.scalar('llk_valid', llk)
        tf.summary.scalar('acc_valid', acc)

    optim_info = get_optimizer_info(ds.name, batch_size=batch_size)
    if args.it > 0:
        optim_info['max_iter'] = int(args.it)
    vae.fit(
        train,
        skip_fitted=True,
        logdir=path,
        on_valid_end=callback,
        clipnorm=100,
        logging_interval=10,
        valid_interval=180,
        nan_gradients_policy='stop',
        **optim_info,
    )
    ## evaluating
    vae.load_weights(f'{path}/model', verbose=True)
    if args.eval:
        evaluate(vae, ds, path, f'{name}')
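
The semi-supervised accuracy in the callback reduces to comparing argmax predictions with one-hot targets; a standalone check on toy data, with y_mean standing in for get_ymean(P[1]):

import numpy as np
from sklearn.metrics import accuracy_score

y = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])                # one-hot truth
y_mean = np.array([[.7, .2, .1], [.1, .8, .1], [.3, .4, .3]])  # predicted class means
acc = accuracy_score(y_true=np.argmax(y, -1), y_pred=np.argmax(y_mean, -1))
print(acc)  # 0.666...; the third sample is misclassified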
Example #11
def main(cfg: dict):
    assert cfg.vae is not None, \
      ('No VAE model given, select one of the following: '
       f"{', '.join(i.__name__.lower() for i in get_vae())}")
    assert cfg.ds is not None, \
      ('No dataset given, select one of the following: '
       'mnist, dsprites, shapes3d, celeba, cortex, newsgroup20, newsgroup5, ...')
    ### load dataset
    ds = get_dataset(name=cfg.ds)
    ds_kw = dict(batch_size=cfg.batch_size, drop_remainder=True)
    ### path, save the output to the subfolder with dataset name
    output_dir = get_output_dir(subfolder=cfg.ds.lower())
    gym_train_path = os.path.join(output_dir, 'gym_train')
    gym_valid_path = os.path.join(output_dir, 'gym_valid')
    gym_test_path = os.path.join(output_dir, 'gym_test')
    model_path = os.path.join(output_dir, 'model')
    ### prepare model init
    model = get_vae(cfg.vae)
    model_kw = inspect.getfullargspec(model.__init__).args[1:]
    model_kw = {k: v for k, v in cfg.items() if k in model_kw}
    is_semi_supervised = ds.has_labels and model.is_semi_supervised()
    if is_semi_supervised:
        train = ds.create_dataset(partition='train',
                                  label_percent=0.1,
                                  **ds_kw)
        valid = ds.create_dataset(partition='valid',
                                  label_percent=1.0,
                                  **ds_kw)
    else:
        train = ds.create_dataset(partition='train', label_percent=0., **ds_kw)
        valid = ds.create_dataset(partition='valid', label_percent=0., **ds_kw)
    ### create the model
    vae = model(path=model_path,
                **get_networks(cfg.ds,
                               centerize_image=True,
                               is_semi_supervised=is_semi_supervised,
                               skip_generator=cfg.skip),
                **model_kw)
    vae.build((None, ) + ds.shape)
    vae.load_weights(raise_notfound=False, verbose=True)
    vae.early_stopping.mode = 'max'
    gym = create_gym(dsname=cfg.ds, vae=vae)

    ### fit the network
    def callback():
        metrics = vae.trainer.last_valid_metrics
        llk = (metrics['llk_image'] if 'llk_image' in metrics
               else metrics['llk_dense_type'])
        vae.early_stopping.update(llk)
        signal = vae.early_stopping(verbose=True)
        if signal > 0:
            vae.save_weights(overwrite=True)
        # create the return metrics
        return dict(**gym.train()(prefix='train/', dpi=150),
                    **gym.valid()(prefix='valid/', dpi=150))

    ### evaluation
    if cfg.eval:
        vae.load_weights()
        gym.train()
        gym(save_path=gym_train_path, dpi=200, verbose=True)
        gym.valid()
        gym(save_path=gym_valid_path, dpi=200, verbose=True)
        gym.test()
        gym(save_path=gym_test_path, dpi=200, verbose=True)
    ### fit
    else:
        vae.early_stopping.patience = 10
        vae.fit(train,
                valid=valid,
                epochs=-1,
                clipnorm=100,
                valid_interval=30,
                logging_interval=2,
                skip_fitted=True,
                on_valid_end=callback,
                logdir=output_dir,
                compile_graph=True,
                track_gradients=True,
                **get_optimizer_info(cfg.ds))
        vae.early_stopping.plot_losses(
            path=os.path.join(output_dir, 'early_stopping.png'))
        vae.plot_learning_curves(
            os.path.join(output_dir, 'learning_curves.png'))
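
main reads cfg both by attribute (cfg.vae, cfg.ds) and via .items(), so the config object must support both; OmegaConf's DictConfig is one option that does (the field values here are placeholders):

from omegaconf import OmegaConf

cfg = OmegaConf.create(dict(vae='betavae', ds='mnist', batch_size=32,
                            skip=False, eval=False))
main(cfg)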