def updatefig(*args): nonlocal done nonlocal obs nonlocal HORIZON nonlocal timestep rebuild, obs_torch = transform_observation(obs) if not done and timestep < HORIZON: #action, action_proba = vape.act(obs_torch) #action = action[0].detach().numpy() action = env.action_space.sample() action = [action[0], 0.3, 0.0] obs, reward, done, info = env.step(action) #env.render(mode='human') timestep += 1 else: done = False obs = env.reset() timestep = 0 c = NHWC(vape.centroids.detach().numpy()) im.set_array( side_by_side( side_by_side(side_by_side(obs, rebuild), c[0, :, :, 0]), c[1, :, :, 0])) vape.optimize_vae(obs_torch, optimizer) time.sleep(0.01) return im,
def show_centroids():
    """Show the first four VAE centroids (channel 0) in a 2x2 greyscale grid."""
    centroids = NHWC(vape.centroids.detach().numpy())
    fig, axes = plt.subplots(nrows=2, ncols=2)
    # Row-major placement: 0 -> (0,0), 1 -> (0,1), 2 -> (1,0), 3 -> (1,1).
    for idx in range(4):
        axes[idx // 2, idx % 2].imshow(centroids[idx, :, :, 0], cmap='Greys')
    plt.show()
def updatefig(*args):
    """Animation callback: act with the VAE policy, then display the current
    observation next to its VAE reconstruction."""
    global done
    global obs
    if done:
        # Previous episode ended: restart the environment.
        done = False
        obs = env.reset()
    else:
        frame = torch.from_numpy(NCHW([obs])).float().to(device)
        mu, _ = vape.encode(frame)  # latent mean; currently unused downstream
        action, logp = vape.act(frame)
        obs, reward, done, info = env.step(action)
        env.render(mode='human')
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
    rebuild, mu, log_sigma, z = vape.encode_decode(obs_torch)
    rebuild = NHWC(rebuild.detach().numpy()[0])
    im.set_array(side_by_side(obs, rebuild))
    time.sleep(0.01)
    return im,
def updatefig(*args): nonlocal done nonlocal obs nonlocal HORIZON nonlocal timestep obs_torch = torch.from_numpy(NCHW([obs])).float().to(device) if not done and timestep < HORIZON: action, action_proba = vape.act(obs_torch) action = action[0].detach().numpy() obs, reward, done, info = env.step(action) env.render(mode='human') timestep += 1 else: done = False obs = env.reset() timestep = 0 rebuild = vape.encode_decode(obs_torch) rebuild = NHWC(rebuild.detach().numpy()[0]) im.set_array(side_by_side(obs, rebuild)) vape.optimize_vae(obs_torch, optimizer) time.sleep(0.01) return im,
def updatefig(*args):
    """Animation callback: step with the restored policy when available
    (acting on the VAE latent mean), otherwise random steering; then draw
    the observation next to its VAE reconstruction."""
    global done
    global obs
    if done:
        # Previous episode ended: restart.
        done = False
        obs = env.reset()
    else:
        if policy:
            # Encode the frame and let the policy act on the latent mean.
            frame = torch.from_numpy(NCHW([obs])).float().to(device)
            mu, _ = vae.encode(frame)
            action, action_proba = policy.act(mu.detach().numpy())
            action = action[0]
        else:
            # Random steering, fixed throttle 0.3, no brake.
            sampled = env.action_space.sample()
            action = [sampled[0], 0.3, 0.0]
        obs, reward, done, info = env.step(action)
        env.render(mode='human')
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
    mu, log_sigma, z, rebuild = vae(obs_torch)
    rebuild = NHWC(rebuild.detach().numpy()[0])
    im.set_array(side_by_side(obs, rebuild))
    time.sleep(0.01)
    return im,
def main():
    """Entry point: build a preprocessed CarRacing env and a VAEPolicy, then
    animate rollouts while training the VAE online.

    BUGFIX: the original created 'weights/' while the checkpoint paths point
    into 'policy_weights/', so saving a checkpoint would fail on a fresh
    clone; the correct directory is now created.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description='REINFORCE using PyTorch')
    # Logging
    parser.add_argument('--alias', type=str, default='base',
                        help="""Alias of the model.""")
    parser.add_argument('--render_interval', type=int, default=100,
                        help='interval between rendered epochs (default: 100)')
    # Learning parameters
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='discount factor (default: 0.99)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='Enables CUDA training')
    parser.add_argument('--eb', type=int, default=1,
                        help='episode batch (default: 1)')
    parser.add_argument('--episodes', type=int, default=10000,
                        help='simulated episodes (default: 10000)')
    parser.add_argument('--policy', type=str, default=None,
                        help="""Policy checkpoint to restore.""")
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed (default: 42)')
    parser.add_argument('--horizon', type=int, default=1000,
                        help='horizon (default: 1000)')
    parser.add_argument('--baseline', action='store_true',
                        help='use the baseline for the REINFORCE algorithm')
    args = parser.parse_args()
    # Check cuda
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    # Initialize environment: crop, resize to 32x32, keep the green channel,
    # normalize pixel values.
    env = gym.make('CarRacing-v0')
    env = CropCarRacing(env)
    env = ResizeObservation(env, (32, 32, 3))
    env = Scolorized(env, weights=[0.0, 1.0, 0.0])
    env = NormalizeRGB(env)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    print("Env final goal:", env.spec.reward_threshold)
    # Create the alias for the run
    alias = '%s_%s' % (args.alias, time.time())
    # Use alias for checkpoints
    checkpoint_best_filename = 'policy_weights/' + alias + '_best.torch'
    checkpoint_final_filename = 'policy_weights/' + alias + '_final.torch'
    # Create the directory the checkpoint paths actually use (was 'weights/').
    os.makedirs('policy_weights/', exist_ok=True)
    # Tensorboard writer
    writer = SummaryWriter('logs/' + alias)
    # Create VAE policy
    vape = VAEPolicy()
    optimizer = optim.Adam(vape.parameters(), lr=1e-04)
    # Animation of environment: first frame + its VAE reconstruction.
    obs = env.reset()
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
    rebuild = vape.encode_decode(obs_torch)
    rebuild = NHWC(rebuild.detach().numpy()[0])
    fig1 = plt.figure()
    if len(obs.shape) == 3 and (obs.shape[-1] == 1):
        # Single-channel observations are rendered in greyscale.
        im = plt.imshow(side_by_side(obs, rebuild), cmap="Greys")
    else:
        im = plt.imshow(side_by_side(obs, rebuild))
    done = False
    # NOTE(review): --horizon is parsed but ignored here (as are --gamma,
    # --lr, etc.); the animation uses a fixed 200-step horizon. Confirm
    # intent before wiring args.horizon in.
    HORIZON = 200
    timestep = 0

    # Setting animation update function
    def updatefig(*args):
        """Per-frame callback: act, step (or reset at episode end/horizon),
        redraw obs | reconstruction, and take one VAE optimization step on
        the pre-step frame."""
        nonlocal done
        nonlocal obs
        nonlocal HORIZON
        nonlocal timestep
        obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
        if not done and timestep < HORIZON:
            action, action_proba = vape.act(obs_torch)
            action = action[0].detach().numpy()
            obs, reward, done, info = env.step(action)
            env.render(mode='human')
            timestep += 1
        else:
            # Episode over or horizon reached: restart.
            done = False
            obs = env.reset()
            timestep = 0
        rebuild = vape.encode_decode(obs_torch)
        rebuild = NHWC(rebuild.detach().numpy()[0])
        im.set_array(side_by_side(obs, rebuild))
        vape.optimize_vae(obs_torch, optimizer)
        time.sleep(0.01)
        return im,

    # Start animation
    ani = animation.FuncAnimation(fig1, updatefig, interval=50, blit=True)
    plt.show()
    # Close env and writer
    env.close()
    writer.close()
# Fragment of a larger routine (enclosing definition not visible in this chunk).
env = Scolorized(env, weights=[0.0, 1.0, 0.0])  # keep only the green channel
env.seed(args.seed)
# Network creation
VAE_class = VAEbyArch(args.arch)
vae = VAE_class(latent_size=args.latent_size).to(device)
# Restore checkpoint
assert args.vae, "No checkpoint provided."
vae.load_state_dict(torch.load(args.vae))
vae.eval()
if args.dataset:
    # Single observation display: reconstruct one dataset sample and show it
    # next to the original.
    mu, log_sigma, z, rebuild = vae(dataset_torch[args.sample:args.sample+1])
    rebuild = rebuild.detach().numpy()[0]
    imshow_bw_or_rgb(side_by_side(dataset[args.sample], NHWC(rebuild)))
    plt.show()
else:
    # Check if we use a policy
    policy = None
    if args.policy and args.vae_old:
        # Restore the policy together with the older VAE it was trained
        # against; presumably vae_old provides the policy's latent space —
        # TODO confirm against the caller.
        policy_env = VAEObservation(env, args.vae_old, arch=args.arch)
        policy = Policy(policy_env)
        policy.load_state_dict(torch.load(args.policy))
        policy.eval()
        vae_old = VAE_class(latent_size=args.latent_size).to(device)
        vae_old.load_state_dict(torch.load(args.vae_old))
        vae_old.eval()
    # Animation of environment
    obs = env.reset()
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
def transform_observation(obs):
    """Run one observation through the VAE.

    Returns a tuple ``(rebuild, obs_torch)``: the decoded reconstruction in
    HWC layout and the NCHW float tensor fed to the network.
    """
    tensor = torch.from_numpy(NCHW([obs])).float().to(device)
    decoded = vape.encode_decode(tensor)
    return NHWC(decoded.detach().numpy()[0]), tensor
# Fragment of a larger routine (enclosing definition not visible in this chunk).
env = ResizeObservation(env, (64, 64, 3))
env = NormalizeRGB(env)
env = Scolorized(env, weights=[0.0, 1.0, 0.0])  # keep only the green channel
env.seed(args.seed)
vape = VAEPolicy()
# Restore checkpoint
assert args.controller, "No checkpoint provided."
vape.load_state_dict(torch.load(args.controller))
vape.eval()
# Animation of environment: first frame + its VAE reconstruction.
obs = env.reset()
obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
rebuild, mu, log_sigma, z = vape.encode_decode(obs_torch)
rebuild = NHWC(rebuild.detach().numpy()[0])
fig1 = plt.figure()
if len(obs.shape) == 3 and (obs.shape[-1] == 1):
    # Single-channel observations are rendered in greyscale.
    im = plt.imshow(side_by_side(obs, rebuild), cmap="Greys")
else:
    im = plt.imshow(side_by_side(obs, rebuild))
done = False
# Setting animation update function
def updatefig(*args):
    global done
    global obs
    if not done:
        obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
        mu, _ = vape.encode(obs_torch)
        # NOTE(review): this definition is truncated here — the rest of the
        # function continues beyond this chunk.