def vae_generation():
    """Sample images from a trained squeeze VAE and show each one.

    Builds a squeeze VAE for single-channel 64x64 input, restores the
    mountain-car checkpoint into it and then, 100 times, calls
    ``model.generate(device)``, converts the result back to image space with
    the mountain-car preprocessor and passes it to ``opencv_show``.
    """
    target_shape = (64, 64)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = "useful_models/squeeze_vae_mountain_car_v0_max_lr_0.005_29122020_143138_050.tar"
    input_shape = (1, ) + target_shape
    model = VariationalAutoEncoder.get_squeeze_vae(
        input_shape=input_shape).to(device)
    load_checkpoint(model_path, model)
    # Inference only: switch layers such as dropout / batch-norm (if the
    # model has any) to evaluation behaviour.
    model.eval()

    # Normalisation statistics of the mountain-car frames the checkpoint
    # above was trained on. For an MNIST checkpoint use
    # bgr_mean=0.1307, bgr_std=0.3081, target_shape=target_shape instead.
    preprocessor = CNNPreProcessor(bgr_mean=0.9857, bgr_std=0.1056)

    # No gradients are needed for sampling; skipping autograd bookkeeping
    # saves memory and yields tensors that do not require grad.
    with torch.no_grad():
        for _ in range(100):
            generated = model.generate(device)
            im_generated = preprocessor.reverse_preprocess(generated)
            opencv_show(im_generated)
    print("")
def evaluation():
    """Reconstruct random MNIST test images with a trained squeeze VAE.

    Builds a squeeze VAE for single-channel 64x64 input and restores the
    MNIST checkpoint into it. Then, 100 times, it picks a random MNIST test
    image, runs it through the model, prints the MSE between the
    preprocessed input and the reconstruction, and shows the resized
    original next to the reconstruction via ``opencv_show``.
    """
    target_shape = (64, 64)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_path = "useful_models/squeeze_vae_mnist_max_lr_0.001_26122020_142330_050.tar"
    input_shape = (1, ) + target_shape
    model = VariationalAutoEncoder.get_squeeze_vae(
        input_shape=input_shape).to(device)
    load_checkpoint(model_path, model)
    # Inference only: switch layers such as dropout / batch-norm (if the
    # model has any) to evaluation behaviour.
    model.eval()

    # MNIST normalisation statistics. For cart-pole frames the statistics
    # previously used here were
    # bgr_mean=(0.9922, 0.9931, 0.9940), bgr_std=(0.0791, 0.0741, 0.0703).
    preprocessor = CNNPreProcessor(bgr_mean=0.1307,
                                   bgr_std=0.3081,
                                   target_shape=target_shape)
    test_images = list(mnist.test_images())

    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory on every forward pass.
    with torch.no_grad():
        for _ in range(100):
            im = test_images[np.random.randint(len(test_images))]
            # Resized copy of the raw image, used only for the visual
            # side-by-side comparison.
            im_target = cv.resize(im, target_shape)

            in_t = preprocessor.preprocess(im).to(device)
            # The model returns three values; only the reconstruction is
            # needed here.
            out_t, _mu, _log_var = model(in_t)
            loss = F.mse_loss(in_t, out_t)

            out_im_target = preprocessor.reverse_preprocess(out_t)
            sbs = put_side_by_side([im_target, out_im_target])
            print(f"Loss: {loss.item()}")
            opencv_show(sbs)
    print("")