Example #1
0
def main(cfg: dict):
    """Train one VAE per (posterior, prior, activation) combination on MNIST.

    For every combination produced by ``product(*posteriors_info)`` a fresh
    VAE is constructed, (re)loaded from its run directory, fitted, and its
    weights saved back to disk.

    NOTE(review): relies on module-level globals (`batch_size`, `max_iter`,
    `encoder`, `decoder`, `encoded_size`, `posteriors_info`, `callback`) —
    assumed to be defined elsewhere in this file; confirm before reuse.
    """
    assert cfg.vae is not None, 'model is not provided in the configuration'
    output_root = get_output_dir()
    # prepare the MNIST splits
    dataset = MNIST()
    image_shape = dataset.shape
    train_ds = dataset.create_dataset(partition='train',
                                      batch_size=batch_size,
                                      shuffle=2000,
                                      drop_remainder=True)
    valid_ds = dataset.create_dataset(partition='valid', batch_size=batch_size)
    test_images, test_onehot = dataset.numpy(partition='test',
                                             shuffle=None,
                                             label_percent=True)
    # map one-hot test labels back to their class names
    test_labels = dataset.labels[np.argmax(test_onehot, axis=-1)]
    # resolve the VAE class once; one model is trained per combination
    model_cls = vi.get_vae(cfg.vae)
    for idx, (posterior, prior,
              activation) in enumerate(product(*posteriors_info)):
        run_name = f"{posterior}_{prior.name}_{activation}"
        run_dir = os.path.join(output_root, run_name)
        os.makedirs(run_dir, exist_ok=True)
        weights_path = os.path.join(run_dir, 'model')
        model = model_cls(encoder=encoder.create_network(),
                          decoder=decoder.create_network(),
                          observation=vi.RVconf(image_shape,
                                                'bernoulli',
                                                projection=True,
                                                name='Image'),
                          latents=vi.RVconf(
                              encoded_size,
                              posterior,
                              projection=True,
                              prior=prior,
                              kwargs=dict(scale_activation=activation),
                              name='Latents'),
                          analytic=False,
                          path=weights_path,
                          name=run_name)
        model.build((None, ) + image_shape)
        # resume from previous weights when they exist
        model.load_weights()
        fitted = model.fit(train=train_ds,
                           valid=valid_ds,
                           max_iter=max_iter,
                           valid_freq=1000,
                           compile_graph=True,
                           skip_fitted=True,
                           on_valid_end=partial(callback,
                                                vae=model,
                                                x=test_images,
                                                y=test_labels),
                           logdir=run_dir,
                           track_gradients=True)
        fitted.save_weights()
Example #2
0
def main(cfg: dict):
    """Train a beta-VAE on MNIST/FashionMNIST with a hand-written ELBO loop.

    Builds either a plain tfp/keras VAE (``cfg.vae == 'tfp'``) or an odin VAE,
    then minimizes ``nll + cfg.beta * kl`` via an explicit gradient-tape step
    (``optimize``) and logs reconstruction/latent/gradient diagnostics every
    1000 steps (``callback``).

    NOTE(review): relies on module-level globals (`batch_size`, `encoded_size`,
    `max_iter`, `create_encoder`, `create_decoder`, dataset classes, ...) —
    assumed to be defined elsewhere in this file; confirm before reuse.
    """
    assert cfg.vae is not None and cfg.beta is not None, \
      f'Invalid arguments: {cfg}'
    # select the dataset
    if cfg.ds == 'mnist':
        ds = MNIST()
    elif cfg.ds == 'fmnist':
        ds = FashionMNIST()
    else:
        raise NotImplementedError(f'No support for dataset with name={cfg.ds}')
    input_shape = ds.shape
    train = ds.create_dataset(partition='train', batch_size=batch_size)
    valid = ds.create_dataset(partition='valid', batch_size=batch_size)
    x_test, y_test = ds.numpy(partition='test',
                              batch_size=batch_size,
                              shuffle=1000,
                              inc_labels=1.0)
    # map one-hot test labels back to their class names
    y_test = ds.labels[np.argmax(y_test, axis=-1)]
    ## create the prior and the network
    # standard-normal prior over the latent vector
    pz = tfd.Sample(tfd.Normal(loc=0, scale=1), sample_shape=encoded_size)
    # NOTE(review): z_samples is never used below — kept so the global RNG
    # sequence (and hence training reproducibility) is unchanged.
    z_samples = pz.sample(16)
    encoder = create_encoder(input_shape)
    decoder = create_decoder()
    ## create the model
    # tfp model API: wire the distribution heads directly into keras
    if cfg.vae == 'tfp':
        encoder.append(tfpl.MultivariateNormalTriL(encoded_size))
        encoder = keras.Sequential(encoder, name='Encoder')
        decoder.append(tfpl.IndependentBernoulli(input_shape))
        decoder = keras.Sequential(decoder, name="Decoder")
        # model returns [p(x|z), q(z|x)] so both are available to `optimize`
        vae = keras.Model(
            inputs=encoder.inputs,
            outputs=[decoder(encoder.outputs[0]), encoder.outputs[0]],
            name='VAE')
    # odin model API
    else:
        encoder = keras.Sequential(encoder, name='Encoder')
        decoder = keras.Sequential(decoder, name="Decoder")
        vae = get_vae(cfg.vae)(
            encoder=encoder,
            decoder=decoder,
            # latents=tfpl.MultivariateNormalTriL(encoded_size),
            latents=RVmeta(event_shape=(encoded_size, ),
                           posterior='mvntril',
                           projection=False,
                           name="Latent"),
            observation=RVmeta(event_shape=input_shape,
                               posterior="bernoulli",
                               projection=False,
                               name="Image"),
        )
    ### training the model
    vae.build(input_shape=(None, ) + input_shape)
    params = vae.trainable_variables
    opt = tf.optimizers.Adam(learning_rate=1e-3)

    def optimize(x, training=None):
        """One training/eval step: return (loss, metrics-dict) for batch x.

        When ``training`` is truthy, also applies Adam updates and reports the
        per-parameter gradient norms under ``_grad/<param name>`` keys (these
        keys are what ``callback`` filters on).
        """
        # only watch `params` explicitly, and only when training
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            if training:
                tape.watch(params)
            px, qz = vae(x, training=training)
            # NOTE(review): `_value()` is a private tfp accessor for the
            # coerced sample tensor — presumably equivalent to
            # tf.convert_to_tensor(qz); verify before upgrading tfp.
            z = qz._value()
            # Monte-Carlo KL(q(z|x) || p(z)) from the single sample z
            kl = tf.reduce_mean(qz.log_prob(z) - pz.log_prob(z), axis=-1)
            nll = -tf.reduce_mean(px.log_prob(x), axis=-1)
            loss = nll + cfg.beta * kl
            if training:
                grads = tape.gradient(loss, params)
                # drop parameters with no gradient (e.g. unused branches)
                grads_params = [(g, p) for g, p in zip(grads, params)
                                if g is not None]
                opt.apply_gradients(grads_params)
                # FIX: grads_params holds (gradient, parameter) pairs, but the
                # original unpacked them as `for p, g in ...`, which keyed the
                # metric by the *gradient* tensor's name and took the norm of
                # the *parameter*. Unpack in (g, p) order so keys use the
                # parameter name (required by callback's 'Encoder'/'Decoder'
                # filter) and norms are taken over gradients.
                grads = {
                    f'_grad/{p.name}': tf.linalg.norm(g)
                    for g, p in grads_params
                }
            else:
                grads = dict()
        return loss, dict(nll=nll, kl=kl, **grads)

    def callback():
        """Periodic diagnostics: images, latent stats and gradient norms."""
        trainer = get_current_trainer()
        x, y = x_test[:1000], y_test[:1000]
        px, qz = vae(x, training=False)
        # latents: per-dimension posterior mean/std averaged over the batch
        qz_mean = tf.reduce_mean(qz.mean(), axis=0)
        qz_std = tf.reduce_mean(qz.stddev(), axis=0)
        # magnitude of the decoder's first-layer weights per latent dimension
        w = tf.reduce_sum(decoder.trainable_variables[0], axis=(0, 1, 2))
        # plot the latents and its weights
        fig = plt.figure(figsize=(6, 4), dpi=200)
        ax = plt.gca()
        l1 = ax.plot(qz_mean,
                     label='mean',
                     linewidth=1.0,
                     linestyle='--',
                     marker='o',
                     markersize=4,
                     color='r',
                     alpha=0.5)
        l2 = ax.plot(qz_std,
                     label='std',
                     linewidth=1.0,
                     linestyle='--',
                     marker='o',
                     markersize=4,
                     color='g',
                     alpha=0.5)
        # weights live on a different scale: use a twin y-axis
        ax1 = ax.twinx()
        l3 = ax1.plot(w,
                      label='weight',
                      linewidth=1.0,
                      linestyle='--',
                      marker='o',
                      markersize=4,
                      color='b',
                      alpha=0.5)
        # merge the legends of both axes into one
        lines = l1 + l2 + l3
        labs = [l.get_label() for l in lines]
        ax.grid(True)
        ax.legend(lines, labs)
        img_qz = vs.plot_to_image(fig)
        # reconstruction of the first 25 test images
        fig = plt.figure(figsize=(5, 5), dpi=120)
        vs.plot_images(np.squeeze(px.mean().numpy()[:25], axis=-1),
                       grids=(5, 5))
        img_res = vs.plot_to_image(fig)
        # 2-D UMAP projection of the latent means, colored by label
        fig = plt.figure(figsize=(5, 5), dpi=200)
        z = fast_umap(qz.mean().numpy())
        vs.plot_scatter(z, color=y, size=12.0, alpha=0.4)
        img_umap = vs.plot_to_image(fig)
        # aggregate gradient norms logged by `optimize` per sub-network
        grads = [(k, v) for k, v in trainer.last_metrics.items()
                 if '_grad/' in k]
        encoder_grad = sum(v for k, v in grads if 'Encoder' in k)
        decoder_grad = sum(v for k, v in grads if 'Decoder' in k)
        return dict(reconstruct=img_res,
                    umap=img_umap,
                    latents=img_qz,
                    qz_mean=qz_mean,
                    qz_std=qz_std,
                    w_decoder=w,
                    llk_test=tf.reduce_mean(px.log_prob(x)),
                    encoder_grad=encoder_grad,
                    decoder_grad=decoder_grad)

    trainer = Trainer(logdir=get_output_dir())
    trainer.fit(train_ds=train.repeat(-1),
                optimize=optimize,
                valid_ds=valid,
                max_iter=max_iter,
                compile_graph=True,
                log_tag=f'{cfg.vae}_{cfg.beta}',
                callback=callback,
                valid_freq=1000)