Example #1
0
def main(spec, num_samples, pool):
    """Sample from a trained generative model and save morphometric measurements.

    Args:
        spec: Setup-spec string; also names the checkpoint subdirectory.
        num_samples: How many images to sample from the model.
        pool: Worker pool handed to measure.measure_batch.
    """
    checkpoint_dir = os.path.join(CHECKPOINT_ROOT, spec)
    model_type, model_args, _dataset_names = spec_util.parse_setup_spec(spec)

    # Build the model/trainer pair for the requested type; reject anything else
    # before doing any work.
    if model_type == 'VAE':
        model = vae.VAE(model_args)
        trainer = vae.Trainer(model, beta=4.)
    elif model_type in ('GAN', 'GANmc'):
        model = gan.GAN(model_args)
        trainer = gan.Trainer(model)
    else:
        raise ValueError(f"Invalid model type: {model_type}")

    # Common load sequence: move to GPU, restore weights, switch to eval mode.
    trainer.cuda()
    models.load_checkpoint(trainer, checkpoint_dir)
    model.eval()

    # Draw samples: the VAE decodes sampled latents; the GAN is called directly.
    if model_type == 'VAE':
        latent = model.sample_latent(num_samples)
        imgs = model.dec(latent)
    else:
        imgs = model(num_samples)

    print(f"Loaded model {checkpoint_dir}. Measuring samples...")
    imgs_np = imgs.detach().cpu().squeeze().numpy()
    metrics = measure.measure_batch(imgs_np, pool=pool)

    os.makedirs(METRICS_ROOT, exist_ok=True)
    metrics_path = os.path.join(METRICS_ROOT, f"{spec}_metrics.csv")
    metrics.to_csv(metrics_path, index_label='index')
    print(f"Morphometrics saved to {metrics_path}")
Example #2
0
def load_vae(path, device=None):
    """Reconstruct a trained VAE from a checkpoint directory.

    Args:
        path: Directory containing 'params.yaml' and 'vae_params.torch'.
        device: Optional device passed through to the VAE constructor.

    Returns:
        A vae_mod.VAE with encoder/decoder sized per the stored config and
        weights restored from the checkpoint.

    Raises:
        ValueError: If the stored 'kernel_dim' is not one of 5, 7, or 16.
    """
    with open(os.path.join(path, 'params.yaml')) as f:
        # safe_load: yaml.load without an explicit Loader is deprecated and can
        # construct arbitrary Python objects from untrusted files.
        vae_args = yaml.safe_load(f)

    kernel_dim = vae_args['kernel_dim']
    if kernel_dim == 5:
        decoder = vae_mod.Decoder5x5(vae_args['z_dim'],
                                     vae_args['hidden_dim'],
                                     var=vae_args['var'])
        encoder = vae_mod.Encoder5x5(vae_args['z_dim'], vae_args['hidden_dim'])
    elif kernel_dim == 7:
        # Was a bare `if`, breaking the chain; behavior is the same but the
        # chain is now consistent.
        decoder = vae_mod.Decoder7x7(vae_args['z_dim'],
                                     vae_args['hidden_dim'],
                                     var=vae_args['var'])
        encoder = vae_mod.Encoder7x7(vae_args['z_dim'], vae_args['hidden_dim'])
    elif kernel_dim == 16:
        decoder = vae_mod.Decoder16x16(vae_args['z_dim'],
                                       vae_args['hidden_dim'],
                                       var=vae_args['var'])
        encoder = vae_mod.Encoder16x16(vae_args['z_dim'],
                                       vae_args['hidden_dim'])
    else:
        # Previously fell through to a NameError on `encoder`/`decoder`;
        # fail with a clear message instead.
        raise ValueError(f"Unsupported kernel_dim: {kernel_dim}")

    vae = vae_mod.VAE(encoder, decoder, device=device)
    vae.load_state_dict(torch.load(os.path.join(path, 'vae_params.torch')))

    return vae
Example #3
0
def main(use_cuda: bool, data_dirs: Union[str, Sequence[str]], weights: Optional[Sequence[Number]],
         ckpt_root: str, latent_dim: int, num_epochs: int,
         batch_size: int, save: bool, resume: bool, plot: bool):
    """Train a beta-VAE on one or more image datasets, with optional
    checkpointing, resuming, and test-batch plotting.

    Args:
        use_cuda: Run on GPU when True.
        data_dirs: One dataset directory or a sequence of them.
        weights: Optional per-dataset sampling weights, forwarded to data_util.
        ckpt_root: Root directory for checkpoints; None disables the ckpt path.
        latent_dim: Dimensionality of the VAE latent space.
        num_epochs: Total epoch count (resumed runs continue toward this total).
        batch_size: Training batch size.
        save: Save a checkpoint after every epoch.
        resume: Resume from the latest checkpoint in ckpt_dir if present.
        plot: Plot test reconstructions after resuming and after each epoch.
    """
    device = torch.device('cuda' if use_cuda else 'cpu')

    if isinstance(data_dirs, str):
        data_dirs = [data_dirs]
    dataset_names = [os.path.split(data_dir)[-1] for data_dir in data_dirs]
    ckpt_name = spec_util.format_setup_spec('VAE', latent_dim, dataset_names)
    print(f"Training {ckpt_name}...")
    ckpt_dir = None if ckpt_root is None else os.path.join(ckpt_root, ckpt_name)

    train_set = data_util.get_dataset(data_dirs, weights, train=True)
    test_set = data_util.get_dataset(data_dirs, weights, train=False)

    test_batch_size = 32
    dl_kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else {}
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, **dl_kwargs)
    test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=True, **dl_kwargs)
    num_batches = len(train_loader.dataset) // train_loader.batch_size

    model = vae.VAE(latent_dim)
    trainer = vae.Trainer(model, beta=4.)
    trainer.to(device)

    test_iterator = iter(test_loader)

    def _next_test_batch():
        # BUG FIX: the original called next() on a single iterator for the whole
        # run, so it raised StopIteration once the test loader was exhausted
        # (after ~len(test_loader) epochs). Re-create the iterator on exhaustion.
        nonlocal test_iterator
        try:
            return next(test_iterator)[0]
        except StopIteration:
            test_iterator = iter(test_loader)
            return next(test_iterator)[0]

    start_epoch = -1
    if resume:
        try:
            start_epoch = load_checkpoint(trainer, ckpt_dir)
            if plot:
                test(model, _next_test_batch())
        except ValueError:
            print(f"No checkpoint to resume from in {ckpt_dir}")
        except FileNotFoundError:
            print(f"Invalid checkpoint directory: {ckpt_dir}")
    elif save:
        if os.path.exists(ckpt_dir):
            print(f"Clearing existing checkpoints in {ckpt_dir}")
            for filename in os.listdir(ckpt_dir):
                os.remove(os.path.join(ckpt_dir, filename))

    for epoch in range(start_epoch + 1, num_epochs):
        trainer.train()
        for batch_idx, (data, _) in enumerate(train_loader):
            verbose = batch_idx % 10 == 0
            if verbose:
                print(f"[{epoch}/{num_epochs}: {batch_idx:3d}/{num_batches:3d}] ", end='')

            # Images arrive as integer arrays (N, H, W); add a channel dim and
            # scale to [0, 1] — presumably uint8 pixels, TODO confirm with data_util.
            real_data = data.to(device).unsqueeze(1).float() / 255.
            trainer.step(real_data, verbose)

        if save:
            save_checkpoint(trainer, ckpt_dir, epoch)

        if plot:
            test(model, _next_test_batch())
Example #4
0
    def __init__(self, image_width: int = 64, image_height: int = 64, latent_dim: int = 32, time_steps: int = 10):
        """Build the two model components: a VAE over frames and an RNN over latents.

        Args:
            image_width: Frame width in pixels.
            image_height: Frame height in pixels.
            latent_dim: Size of the VAE latent vector (also the RNN in/out size).
            time_steps: Sequence length for the RNN.
        """
        self.image_height = image_height
        self.image_width = image_width
        self.latent_dim = latent_dim
        self.time_steps = time_steps

        # Frames are (H, W, 3) — RGB channels last.
        self.image_input_shape = (image_height, image_width, 3)

        # Frame autoencoder; its training config is filled in later (starts as None).
        self.model_vae = vae.VAE(self.image_input_shape, self.latent_dim)
        self.vae_train_config_dict = None

        # Latent-sequence model with its MDN layer enabled (per the flag name).
        self.model_rnn = rnn.RNN(self.latent_dim, self.latent_dim, self.time_steps, contain_mdn_layer=True)
        self.rnn_train_config_dict = None
Example #5
0
def create_model(config, gen_bias_init=0.0, data_type='binary'):
    """Creates a tf.keras.Model object.

    Args:
        config: A configuration object with config values accessible as properties.
            Most likely a FLAGS object.
        gen_bias_init: Bias initialisation of generative model. Usually set to
            something sensible like the mean of the training set.
        data_type: Whether modelling real-valued or binary data.

    Returns:
        model: The constructed deep generative model.

    Raises:
        ValueError: If config.model names no known model type.
    """

    if config.model == 'discvae':
        # Create a discvae.DiSCVAE object
        model = discvae.DiSCVAE(static_latent_size=config.latent_size,
                                dynamic_latent_size=config.dynamic_latent_size,
                                mix_components=config.mixture_components,
                                hidden_size=config.hidden_size,
                                rnn_size=config.rnn_size,
                                num_channels=config.num_channels,
                                decoder_type=data_type,
                                encoded_data_size=config.encoded_data_size,
                                encoded_latent_size=config.encoded_latent_size,
                                sigma_min=1e-5,
                                raw_sigma_bias=0.5,
                                gen_bias_init=gen_bias_init,
                                beta=config.beta,
                                temperature=config.init_temp)
    elif config.model == 'vrnn':
        # Create a vrnn.VRNN object
        model = vrnn.VRNN(latent_size=config.latent_size,
                          hidden_size=config.hidden_size,
                          rnn_size=config.rnn_size,
                          num_channels=config.num_channels,
                          decoder_type=data_type,
                          encoded_data_size=config.encoded_data_size,
                          encoded_latent_size=config.encoded_latent_size,
                          sigma_min=1e-5,
                          raw_sigma_bias=0.5,
                          gen_bias_init=gen_bias_init)
    elif config.model == 'gmvae':
        # Create a gmvae.GMVAE object
        model = gmvae.GMVAE(latent_size=config.latent_size,
                            mix_components=config.mixture_components,
                            hidden_size=config.hidden_size,
                            num_channels=config.num_channels,
                            decoder_type=data_type,
                            encoded_data_size=config.encoded_data_size,
                            sigma_min=1e-5,
                            raw_sigma_bias=0.5,
                            temperature=config.init_temp)
    elif config.model == 'vae':
        # Create a vae.VAE object
        model = vae.VAE(latent_size=config.latent_size,
                        hidden_size=config.hidden_size,
                        num_channels=config.num_channels,
                        decoder_type=data_type,
                        encoded_data_size=config.encoded_data_size,
                        sigma_min=1e-5,
                        raw_sigma_bias=0.5)
    else:
        # BUG FIX: previously this branch only logged and then fell through to
        # `return model`, raising UnboundLocalError. Raise a clear error instead.
        msg = "No tf.keras.Model available by the name {}".format(config.model)
        logging.error(msg)
        raise ValueError(msg)

    return model
Example #6
0
        add = np.zeros(shape=z.shape)
        for i in range(-3, 3):
            add[0] = i
            z_new = z + add
            self.interpolated_latents.append(z_new)


if __name__ == '__main__':
    image_size = 96
    seed = 100
    USE_CUDA = torch.cuda.is_available()
    generative_model = vae.VAE(conv_layers=32,
                               z_dimension=64,
                               pool_kernel_size=2,
                               conv_kernel_size=4,
                               input_channels=3,
                               height=96,
                               width=96,
                               hidden_dim=128,
                               use_cuda=USE_CUDA)
    denoising_autoencoder = vae.DAE(conv_layers=16,
                                    conv_kernel_size=3,
                                    pool_kernel_size=2,
                                    height=96,
                                    width=96,
                                    input_channels=3,
                                    hidden_dim=64,
                                    noise_scale=0.3,
                                    use_cuda=USE_CUDA)

    if USE_CUDA:
Example #7
0
        z = self.latents[0]
        self.interpolated_latents = []
        add = np.zeros(shape=z.shape)
        for i in range(-3, 3):
            add[0] = i
            z_new = z + add
            self.interpolated_latents.append(z_new)


if __name__ == '__main__':
    image_size = 96
    seed = 100
    generative_model = vae.VAE(conv_layers=16,
                               z_dimension=32,
                               pool_kernel_size=2,
                               conv_kernel_size=3,
                               input_channels=3,
                               height=96,
                               width=96,
                               hidden_dim=64)
    denoising_autoencoder = vae.DAE(conv_layers=16,
                                    conv_kernel_size=3,
                                    pool_kernel_size=2,
                                    height=96,
                                    width=96,
                                    input_channels=3,
                                    hidden_dim=64,
                                    noise_scale=0.3)
    trainer = Trainer(beta=1,
                      generative_model=generative_model,
                      learning_rate=1e-2,
                      num_epochs=30,
Example #8
0
# Select GPU 0 when available; model and inputs are moved to this device below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)

# Load the dataset; the trailing dimension D selects which encoder/decoder
# architecture matches the data.
data = np.load(args.data)
D = data.shape[-1]

# NOTE(review): there is no else branch — if D is neither 3 nor 5,
# `encoder`/`decoder` are never bound and the VAE construction below
# raises NameError. Confirm whether other D values can occur here.
if D == 3:
    decoder = vae_module.Decoder3x3(args.z_dim, args.hidden_dim)
    encoder = vae_module.Encoder3x3(args.z_dim, args.hidden_dim)
elif D == 5:
    decoder = vae_module.Decoder5x5(args.z_dim, args.hidden_dim)
    encoder = vae_module.Encoder5x5(args.z_dim, args.hidden_dim)

# Assemble the VAE and restore trained weights from the checkpoint path.
vae = vae_module.VAE(encoder, decoder, device=device)
vae.load_state_dict(torch.load(args.model))

# Reconstructions: draw n*m random samples and plot each input next to its
# reconstruction as a heatmap grid.
n, m = 10, 5
np.random.seed(42)  # fixed seed so the sampled indices are reproducible
fig, axes = plt.subplots(figsize=(15, 12), nrows=n, ncols=m)
img_idxs = np.random.randint(0, len(data), size=(n * m))

for i, ax in enumerate(axes.flat):
    c = data[img_idxs[i]][np.newaxis]  # add a leading batch dimension
    inp = torch.FloatTensor(c).to(device)
    # The VAE's second output unpacks as [rec, _]; rec is the reconstruction.
    _, [rec, _] = vae(inp)
    rec = tonp(rec)
    # Place input and reconstruction side by side along the last axis,
    # then drop the batch/channel dims for plotting.
    c = np.concatenate([c, rec], 3)[0, 0]
    sns.heatmap(c, ax=ax)
Example #9
0
                                             shuffle=False)

    assert args.kernel_dim == D, '--kernel-dim != D (in dataset)'

    if D == 3:
        decoder = vae.Decoder3x3(args.z_dim, args.hidden_dim)
        encoder = vae.Encoder3x3(args.z_dim, args.hidden_dim)
    elif D == 5:
        decoder = vae.Decoder5x5(args.z_dim, args.hidden_dim, var=args.var)
        encoder = vae.Encoder5x5(args.z_dim, args.hidden_dim)
    elif D == 7:
        decoder = vae.Decoder7x7(args.z_dim, args.hidden_dim, var=args.var)
        encoder = vae.Encoder7x7(args.z_dim, args.hidden_dim)
    elif D == 16:
        decoder = vae.Decoder16x16(args.z_dim, args.hidden_dim, var=args.var)
        encoder = vae.Encoder16x16(args.z_dim, args.hidden_dim)

    vae = vae.VAE(encoder, decoder, device=device)
    if args.resume_vae:
        vae.load_state_dict(torch.load(args.resume_vae))
    optimizer = torch.optim.Adam(vae.parameters(), lr=args.lr)
    if args.resume_opt:
        optimizer.load_state_dict(torch.load(args.resume_opt))
        optimizer.param_groups[0]['lr'] = args.lr
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_decay_step,
                                                args.decay)
    criterion = utils.VAEELBOLoss(use_cuda=args.cuda)

    train(trainloader, testloader, vae, optimizer, scheduler, criterion, args,
          D)