def load_model(
    model_type,
    max_filters,
    num_layers,
    image_size,
    model_parameters,
    small_conv,
    model_path,
    device,
):
    """Build an (auto)encoder model, load saved weights, and prepare it for inference.

    Args:
        model_type: ``"vae"``, ``"dual_input_vae"``, ``"vq_vae"`` or
            ``"dual_input_autoencoder"``. Any other value falls back to
            ``autoencoder.ConvolutionalAE``.
        max_filters: Forwarded to the model constructor (not used for
            ``"vq_vae"``).
        num_layers: Forwarded to the model constructor.
        image_size: Forwarded as ``input_image_dimensions``.
        model_parameters: For ``"vq_vae"``, a mapping read for the keys
            ``"K"`` (num_embeddings), ``"D"`` (embedding_dim) and
            ``"commitment_cost"``. For every other type it is forwarded
            as ``latent_dim``.
        small_conv: Forwarded to the model constructor.
        model_path: Path to a file holding a ``state_dict`` for the model.
        device: ``torch.device`` (or device string) used both as the
            ``map_location`` for loading and as the target device for the model.

    Returns:
        The model with weights loaded, moved to ``device``, in eval mode.

    Raises:
        KeyError: If ``model_type == "vq_vae"`` and ``model_parameters`` lacks
            one of the required keys.
    """
    if model_type == "vq_vae":
        # Read keys in the same order as before so a missing key reports
        # the same KeyError.
        num_embeddings = model_parameters["K"]
        embedding_dim = model_parameters["D"]
        commitment_cost = model_parameters["commitment_cost"]
        model = vqvae.VQVAE(
            num_layers=num_layers,
            input_image_dimensions=image_size,
            small_conv=small_conv,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            commitment_cost=commitment_cost,
        )
    else:
        # The remaining architectures share one constructor signature, so
        # only the class differs. Unknown types fall back to ConvolutionalAE.
        if model_type == "vae":
            model_cls = vae.ConvolutionalVAE
        elif model_type == "dual_input_vae":
            model_cls = vae.FusionVAE
        elif model_type == "dual_input_autoencoder":
            model_cls = autoencoder.FusionAE
        else:
            model_cls = autoencoder.ConvolutionalAE
        model = model_cls(
            max_filters=max_filters,
            num_layers=num_layers,
            input_image_dimensions=image_size,
            latent_dim=model_parameters,
            small_conv=small_conv,
        )

    model.load_state_dict(torch.load(model_path, map_location=device))
    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone leaves the model where it was built (CPU). Move it
    # explicitly so the returned model actually lives on `device`.
    model.to(device)
    model.eval()
    return model
################################################################################
#################################### Setup #####################################
################################################################################

# Setup Device
gpu = torch.cuda.is_available()
device = torch.device("cuda" if gpu else "cpu")
print(gpu, device)

# Create Output Paths
# exist_ok=True replaces a separate os.path.exists check, which could race
# with another process creating the directory in between.
os.makedirs(output_dir, exist_ok=True)

# Create & Load VQVAE Model
vq_vae = vqvae.VQVAE(**vq_vae_model_config)
vq_vae.load_state_dict(torch.load(vq_vae_model_path, map_location=device))
vq_vae.eval()
vq_vae.to(device)

# Load the PixelCNN checkpoint: a dict read below for "encoding_dict"
# (handed to the label handler) and "model_state_dict" (the weights).
checkpoint = torch.load(model_path, map_location=device)
label_handler = data.ConditioningLabelsHandlerFromSaved(
    conditioning_info_file,
    conditioning_info_columns,
    checkpoint["encoding_dict"],
)
# The model's conditioning input width is taken from the label handler.
model_config["conditioning_size"] = label_handler.get_size()

# Create Model
model = gated_pixelcnn.ConditionalPixelCNN(**model_config)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
# Prior input/output sizes. The output size is the image size divided by
# 2**vq_vae_num_layers; presumably each VQ-VAE encoder layer downsamples by
# 2 — confirm against vqvae.VQVAE.
prior_input_dim = image_size
prior_output_dim = prior_input_dim // np.power(2, vq_vae_num_layers)

## Setup Devices
gpu = torch.cuda.is_available()
device = torch.device("cuda" if gpu else "cpu")

## Setup Transform
transform = data.image2tensor_resize(image_size)

## Load VQVAE Model
model = vqvae.VQVAE(
    num_layers=vq_vae_num_layers,
    input_image_dimensions=image_size,
    small_conv=vq_vae_small_conv,
    embedding_dim=vq_vae_embedding_dim,
    num_embeddings=vq_vae_num_embeddings,
    commitment_cost=vq_vae_commitment_cost,
    use_max_filters=vq_vae_use_max_filters,
    max_filters=vq_vae_max_filters,
)
model_path = os.path.join(model_prefix, vq_vae_model_name, "model.pt")
model.load_state_dict(torch.load(model_path, map_location=device))
# map_location only controls where the loaded tensors are placed;
# load_state_dict copies them into the model's own (CPU) parameters, so the
# model must be moved to `device` explicitly.
model.to(device)
model.eval()

## Load Prior
prior = cnn_prior.CNNPrior(
    input_channels=prior_input_channels,
    output_channels=prior_output_channels,
    input_dim=prior_input_dim,
    output_dim=prior_output_dim,
)