parser.add_argument("--loadpth", type=str, default='./results/vqvae_data_bo.pth')
parser.add_argument("--data_dir", type=str, default='/home/karam/Downloads/bco/')
# Restrict --data to the datasets handled below; any other value used to fall
# through the if/elif and crash later with a NameError on `data`.
parser.add_argument("--data", type=str, default='bco', choices=['bco', 'bo'])
args = parser.parse_args()

# Validate with a real CLI error: an `assert` is stripped under `python -O`,
# and the old `is not ''` compared object identity rather than string value.
if not args.loadpth:
    parser.error("--loadpth must be a non-empty checkpoint path")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model (hyper-parameter arguments are declared earlier in the file)
model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta).to(device)
# map_location lets a checkpoint saved on GPU load on the CPU fallback device.
model.load_state_dict(torch.load(args.loadpth, map_location=device)['model'])
model.eval()
print("Loaded model")

# Load data
save_dir = os.path.join(os.getcwd(), 'data')
data_dir = args.data_dir
if args.data == 'bco':
    # The 'bco' dataset is stored as four shards, concatenated along axis 0.
    data1, data2, data3, data4 = (
        np.load(os.path.join(data_dir, f"bcov5_{i}.npy")) for i in range(4)
    )
    data = np.concatenate((data1, data2, data3, data4), axis=0)
elif args.data == 'bo':
    # NOTE(review): allow_pickle=True can execute arbitrary code embedded in
    # the file; only load trusted data files here.
    data = np.load(os.path.join(data_dir, "bo-150-50-20.npy"), allow_pickle=True)
# Leading two axes are used as (number of trajectories, trajectory length).
n_trajs, length = data.shape[:2]
n_trajs, length = data.shape[:2]
img_dim = args.img_dim

# PixelCNN prior over the VQ-VAE codebook indices of an img_dim x img_dim grid.
model = GatedPixelCNN(
    n_embeddings=args.n_embeddings,
    imgximg=args.img_dim ** 2,
    n_layers=args.n_layers,
    conditional=args.conditional,
    x_one_hot=args.x_one_hot,
    c_one_hot=args.c_one_hot,
    n_cond_res_block=args.n_cres_layers,
).to(device)
model.train()
# Follow the selected device instead of hard-coding CUDA, so CPU-only runs work.
criterion = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Optionally load a pretrained VQ-VAE (truthiness also skips a None path; the
# old `is not ''` compared identity, not string value).
if args.loadpth_vq:
    vae = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
                args.n_embeddings, args.embedding_dim, args.beta).to(device)
    vae.load_state_dict(torch.load(args.loadpth_vq, map_location=device)['model'])
    print("VQ Loaded")
    vae.eval()
    if args.data == 'bco':
        # Encode the conditioning images to latent codes; no_grad avoids
        # building an autograd graph for this inference-only pass.
        # NOTE(review): sample_c_imgs is defined outside this chunk.
        with torch.no_grad():
            sample_c = (vae(sample_c_imgs, latent_only=True)
                        .detach().cpu().numpy().reshape(-1, length).squeeze())

# Optionally resume the PixelCNN from a checkpoint (a bare state dict).
if args.loadpth_pcnn:
    model.load_state_dict(torch.load(args.loadpth_pcnn, map_location=device))
    print("PCNN Loaded")

# Batch bookkeeping; context, val and valcon are defined outside this chunk.
n_trajs = len(data)
dt = n_trajs // context.shape[0]
n_batch = n_trajs // args.batch_size
n_trajs_t = len(val)
dv = n_trajs_t // valcon.shape[0]
n_batch_t = n_trajs_t // args.batch_size