def main(
    *,
    data_root="/home/zainkhan/bci-representation-learning",
    train=True,
):
    """Train (or load) the EEG VAE and the pupil CAN, then run the VAE on test EEG.

    Keyword Args:
        data_root: Directory handed to ``utils.read_single_trial_datasets``.
            Defaults to the location that was previously hard-coded here.
        train: When True, train the VAE on EEG and the CAN on pupil data, then
            save all four encoder/decoder models to the working directory.
            When False, load those previously saved models instead.
    """
    # Hyperparameters.
    vae_epoch = 2
    can_epoch = 1000
    batch_size = 64
    latent_dim = 10
    beta_eeg = 5.0

    # Read data sets (pupil_test and sub_cond are not used in the visible code).
    eeg_train, eeg_test, pupil_train, pupil_test, sub_cond = (
        utils.read_single_trial_datasets(data_root))

    if train:
        # Train the VAE on EEG.
        vae = VAE(beta=beta_eeg, latent_dim=latent_dim)
        vae.compile(optimizer=keras.optimizers.Adam())
        vae.fit(eeg_train, epochs=vae_epoch, batch_size=batch_size)

        # Save the VAE. The train=False branch below loads exactly these
        # paths, so they must be written here; otherwise that branch fails or
        # pairs the newly saved CAN with a stale VAE.
        vae.encoder.save("vae_encoder")
        vae.decoder.save("vae_decoder")

        # Train the CAN on pupil data, built on top of the trained VAE.
        can = CAN(
            vae=vae,
            can_data=pupil_train,
            vae_data=eeg_train,
            latent_dim=latent_dim,
            epochs=can_epoch,
            batch_size=batch_size,
        )
        can.compile(optimizer=keras.optimizers.Adam(), run_eagerly=True)
        can.fit(pupil_train,
                epochs=can_epoch,
                batch_size=batch_size,
                shuffle=False)

        # Save the CAN.
        can.encoder.save("can_encoder")
        can.decoder.save("can_decoder")
    else:
        # Rebuild the wrappers and swap in the previously saved sub-models.
        vae = VAE(beta=beta_eeg, latent_dim=latent_dim)
        vae.encoder = keras.models.load_model("vae_encoder")
        vae.decoder = keras.models.load_model("vae_decoder")

        can = CAN(vae=vae, vae_data=eeg_train, latent_dim=latent_dim)
        can.encoder = keras.models.load_model("can_encoder")
        can.decoder = keras.models.load_model("can_decoder")

    # VAE predictions on the held-out EEG.
    encoded_data = vae.encoder.predict(eeg_test)
    decoded_data = vae.decoder.predict(encoded_data)
    # NOTE(review): decoded_data and fn are not used in the visible code; the
    # step that writes the predictions to fn appears to be missing - confirm.
    fn = utils.get_filename("predictions/", "test-eeg")
Example No. 2
0
    # NOTE(review): this is the tail of a function whose header is not
    # visible; real_imgs, real_dl, num_samples, model, args, device and
    # tb_writer come from that missing part. Presumably real_imgs starts as
    # None so the loop below runs at least once - confirm.

    # Pull batches from real_dl until at least num_samples real images are
    # collected. imgs[0] is taken from each batch - presumably the image
    # tensor of an (images, labels) pair; confirm.
    # NOTE(review): re-running torch.cat on the growing tensor recopies it on
    # every iteration (quadratic in the number of batches); appending to a
    # list and concatenating once would avoid that.
    while real_imgs is None or real_imgs.size(0) < num_samples:
        imgs = next(real_dl)
        if real_imgs is None:
            real_imgs = imgs[0]
        else:
            real_imgs = torch.cat((real_imgs, imgs[0]), 0)
    # Trim to exactly num_samples and expand dim 1 to size 3 (rank-4 tensor).
    # expand() only succeeds if dim 1 is already 1 or 3 - assumes NCHW images
    # with a single channel being broadcast to 3; confirm.
    real_imgs = real_imgs[:num_samples].expand(-1, 3, -1, -1)

    # Generate num_samples images from standard-normal latent codes, one
    # batch of args.batch_size at a time, without tracking gradients.
    # NOTE(review): model.forward(...) bypasses nn.Module hooks; model(...) is
    # the usual call. This uses model.latent_dim while the section below uses
    # args.latent_dim - confirm the two agree.
    with torch.no_grad():
        samples = None
        while samples is None or samples.size(0) < num_samples:
            imgs = model.forward(
                z=torch.randn(args.batch_size, model.latent_dim).to(device))
            if samples is None:
                samples = imgs
            else:
                samples = torch.cat((samples, imgs), 0)
    # Same trim/expand as for the real images, then move the samples to CPU.
    samples = samples[:num_samples].expand(-1, 3, -1, -1)
    samples = samples.cpu()

    # Compute FID between real and generated images, log it under 'fid' and
    # print it.
    fid = fid_score.calculate_fid_given_images(real_imgs, samples,
                                               args.batch_size, device)
    tb_writer.add_scalar('fid', fid)
    print("FID score: {:.3f}".format(fid), flush=True)

    # Encode 4 fresh real images, split the encoder output in half along the
    # last dim into (mu, log_var), sample latents via reparameterize, and
    # hand them to interpolate().
    # NOTE(review): `noise` is never used below, but torch.randn still
    # advances the RNG. These calls also run outside torch.no_grad(), so the
    # encoder pass builds an autograd graph.
    imgs = next(real_dl)[0][0:4].to(device)
    noise = torch.randn(4, args.latent_dim).to(device)
    mu, log_var = model.encoder(imgs).chunk(2, dim=-1)
    sampled_z = model.reparameterize(mu, log_var).unsqueeze(1)
    interpolate(model, sqrt_K=14, noises=sampled_z)