Beispiel #1
0
def load_model(
    model_type,
    max_filters,
    num_layers,
    image_size,
    model_parameters,
    small_conv,
    model_path,
    device,
):
    """Build a model of the requested type, load saved weights and return it.

    Args:
        model_type: ``"vae"``, ``"dual_input_vae"``, ``"vq_vae"`` or
            ``"dual_input_autoencoder"``. Any other value builds an
            ``autoencoder.ConvolutionalAE``.
        max_filters: Forwarded as ``max_filters`` to every architecture
            except the VQ-VAE, which does not receive it here.
        num_layers: Forwarded as ``num_layers``.
        image_size: Forwarded as ``input_image_dimensions``.
        model_parameters: For ``"vq_vae"`` a mapping with keys ``"K"``
            (``num_embeddings``), ``"D"`` (``embedding_dim``) and
            ``"commitment_cost"``; for every other type it is forwarded
            unchanged as ``latent_dim``.
        small_conv: Forwarded as ``small_conv``.
        model_path: File holding a ``state_dict`` readable by ``torch.load``.
        device: Device the stored tensors are mapped to and the model is
            moved to.

    Returns:
        The model with its weights loaded, placed on ``device`` and switched
        to eval mode.

    Raises:
        KeyError: If ``model_type == "vq_vae"`` and ``model_parameters`` lacks
            one of the required keys.
    """
    if model_type == "vq_vae":
        num_embeddings = model_parameters["K"]
        embedding_dim = model_parameters["D"]
        commitment_cost = model_parameters["commitment_cost"]
        model = vqvae.VQVAE(
            num_layers=num_layers,
            input_image_dimensions=image_size,
            small_conv=small_conv,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            commitment_cost=commitment_cost,
        )
    else:
        # The remaining architectures share one constructor signature, so pick
        # the class and build it once instead of repeating the keyword list.
        # Unknown types fall back to the plain convolutional autoencoder.
        architectures = {
            "vae": vae.ConvolutionalVAE,
            "dual_input_vae": vae.FusionVAE,
            "dual_input_autoencoder": autoencoder.FusionAE,
        }
        model_cls = architectures.get(model_type, autoencoder.ConvolutionalAE)
        model = model_cls(
            max_filters=max_filters,
            num_layers=num_layers,
            input_image_dimensions=image_size,
            latent_dim=model_parameters,
            small_conv=small_conv,
        )

    # NOTE: torch.load is pickle-based; only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters, which
    # keep their device, so map_location alone does not put the model on
    # `device`; move it explicitly.
    model.to(device)
    model.eval()
    return model
################################################################################
##################################### Setup ####################################
################################################################################

# Setup Device: prefer CUDA when available, otherwise run on the CPU.
gpu = torch.cuda.is_available()
device = torch.device("cuda" if gpu else "cpu")
print(gpu, device)

# Create Output Paths. exist_ok=True removes the race between a separate
# existence check and the directory creation.
os.makedirs(output_dir, exist_ok=True)

# Create & Load VQVAE Model, then put it in eval mode on the selected device.
vq_vae = vqvae.VQVAE(**vq_vae_model_config)
vq_vae.load_state_dict(torch.load(vq_vae_model_path, map_location=device))
vq_vae.eval()
vq_vae.to(device)

# Load the PixelCNN checkpoint: a dict holding the weights under
# "model_state_dict" and the label encoding under "encoding_dict".
checkpoint = torch.load(model_path, map_location=device)

# Rebuild the conditioning-label handler from the encoding stored in the
# checkpoint (presumably the one used at training time) and size the model's
# conditioning input from it.
label_handler = data.ConditioningLabelsHandlerFromSaved(
    conditioning_info_file, conditioning_info_columns,
    checkpoint["encoding_dict"])
model_config["conditioning_size"] = label_handler.get_size()

# Create Model, restore its weights and move it to the device in eval mode.
model = gated_pixelcnn.ConditionalPixelCNN(**model_config)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
    # NOTE(review): these lines are the body of a function whose `def` line is
    # not part of this view; names such as image_size, vq_vae_*, prior_*,
    # model_prefix, data, vqvae and cnn_prior are presumably its parameters or
    # module-level imports/globals — confirm against the full file.

    # The prior's output side length is its input size integer-divided by
    # 2**vq_vae_num_layers — presumably the VQ-VAE latent grid size if each
    # VQ-VAE layer halves the resolution; TODO confirm.
    prior_input_dim = image_size
    prior_output_dim = prior_input_dim // np.power(2, vq_vae_num_layers)

    ## Setup Devices
    # Prefer CUDA when available, otherwise fall back to the CPU.
    gpu = torch.cuda.is_available()
    device = torch.device("cuda" if gpu else "cpu")

    ## Setup Transform
    # Transform built by the project helper data.image2tensor_resize for inputs
    # of size `image_size` (its exact behavior is not shown here).
    transform = data.image2tensor_resize(image_size)

    ## Load VQVAE Model
    # Rebuild the VQ-VAE with the same hyperparameters as the saved model so
    # that the state_dict loaded below matches its parameters.
    model = vqvae.VQVAE(
        num_layers=vq_vae_num_layers,
        input_image_dimensions=image_size,
        small_conv=vq_vae_small_conv,
        embedding_dim=vq_vae_embedding_dim,
        num_embeddings=vq_vae_num_embeddings,
        commitment_cost=vq_vae_commitment_cost,
        use_max_filters=vq_vae_use_max_filters,
        max_filters=vq_vae_max_filters,
    )
    # Weights are expected at <model_prefix>/<vq_vae_model_name>/model.pt.
    model_path = os.path.join(model_prefix, vq_vae_model_name, "model.pt")
    # map_location remaps the stored tensors onto `device`, but load_state_dict
    # copies them into the model's existing parameters, which keep their device.
    # NOTE(review): no model.to(device) is visible here — confirm the model is
    # moved later if GPU inference is intended.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    ## Load Prior
    # Construct the CNN prior mapping prior_input_dim -> prior_output_dim;
    # its weights are not loaded within the lines shown here.
    prior = cnn_prior.CNNPrior(
        input_channels=prior_input_channels,
        output_channels=prior_output_channels,
        input_dim=prior_input_dim,
        output_dim=prior_output_dim,
    )