def updatefig(*args):
     nonlocal done
     nonlocal obs
     nonlocal HORIZON
     nonlocal timestep
     rebuild, obs_torch = transform_observation(obs)
     if not done and timestep < HORIZON:
         #action, action_proba = vape.act(obs_torch)
         #action = action[0].detach().numpy()
         action = env.action_space.sample()
         action = [action[0], 0.3, 0.0]
         obs, reward, done, info = env.step(action)
         #env.render(mode='human')
         timestep += 1
     else:
         done = False
         obs = env.reset()
         timestep = 0
     c = NHWC(vape.centroids.detach().numpy())
     im.set_array(
         side_by_side(
             side_by_side(side_by_side(obs, rebuild), c[0, :, :, 0]),
             c[1, :, :, 0]))
     vape.optimize_vae(obs_torch, optimizer)
     time.sleep(0.01)
     return im,
 def show_centroids():
     """Display the first four VAE centroids as grayscale images in a 2x2 grid."""
     grid = NHWC(vape.centroids.detach().numpy())
     fig, axes = plt.subplots(nrows=2, ncols=2)
     # flatten() walks the axes row-major, matching centroid order 0..3.
     for idx, axis in enumerate(axes.flatten()):
         axis.imshow(grid[idx, :, :, 0], cmap='Greys')
     plt.show()
Beispiel #3
0
 def updatefig(*args):
     """Animation callback: step the env with the VAE policy's action and
     redraw the observation next to its reconstruction."""
     global done
     global obs
     if done:
         # Episode over: reset and draw the fresh observation below.
         done = False
         obs = env.reset()
     else:
         obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
         mu, _ = vape.encode(obs_torch)
         action, logp = vape.act(obs_torch)
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     rebuild, mu, log_sigma, z = vape.encode_decode(obs_torch)
     rebuild = NHWC(rebuild.detach().numpy()[0])
     im.set_array(side_by_side(obs, rebuild))
     time.sleep(0.01)
     return im,
Beispiel #4
0
 def updatefig(*args):
     nonlocal done
     nonlocal obs
     nonlocal HORIZON
     nonlocal timestep
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     if not done and timestep < HORIZON:
         action, action_proba = vape.act(obs_torch)
         action = action[0].detach().numpy()
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
         timestep += 1
     else:
         done = False
         obs = env.reset()
         timestep = 0
     rebuild = vape.encode_decode(obs_torch)
     rebuild = NHWC(rebuild.detach().numpy()[0])
     im.set_array(side_by_side(obs, rebuild))
     vape.optimize_vae(obs_torch, optimizer)
     time.sleep(0.01)
     return im,
Beispiel #5
0
 def updatefig(*args):
     """Animation callback: act via the latent-space policy when available
     (random action otherwise) and show the VAE reconstruction."""
     global done
     global obs
     if done:
         # Episode over: reset before drawing.
         done = False
         obs = env.reset()
     else:
         if policy:
             obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
             mu, _ = vae.encode(obs_torch)
             action, action_proba = policy.act(mu.detach().numpy())
             action = action[0]
         else:
             # No policy loaded: random steering, fixed throttle, no brake.
             sampled = env.action_space.sample()
             action = [sampled[0], 0.3, 0.0]
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     mu, log_sigma, z, rebuild = vae(obs_torch)
     im.set_array(side_by_side(obs, NHWC(rebuild.detach().numpy()[0])))
     time.sleep(0.01)
     return im,
Beispiel #6
0
def main():
    """Run a live CarRacing animation while training a VAE policy online.

    Parses CLI arguments, builds the wrapped environment, creates the
    VAEPolicy and its optimizer, and drives a matplotlib animation whose
    per-frame callback steps the env with the policy, displays the
    observation next to its VAE reconstruction, and performs one VAE
    optimization step.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description='REINFORCE using PyTorch')
    # Logging
    parser.add_argument('--alias',
                        type=str,
                        default='base',
                        help="""Alias of the model.""")
    parser.add_argument('--render_interval',
                        type=int,
                        default=100,
                        help='interval between rendered epochs (default: 100)')
    # Learning parameters
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='discount factor (default: 0.99)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        help='learning rate (default: 0.01)')

    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Enables CUDA training')
    parser.add_argument('--eb',
                        type=int,
                        default=1,
                        help='episode batch (default: 1)')
    parser.add_argument('--episodes',
                        type=int,
                        default=10000,
                        help='simulated episodes (default: 10000)')
    parser.add_argument('--policy',
                        type=str,
                        default=None,
                        help="""Policy checkpoint to restore.""")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed (default: 42)')
    parser.add_argument('--horizon',
                        type=int,
                        default=1000,
                        help='horizon (default: 1000)')
    parser.add_argument('--baseline',
                        action='store_true',
                        help='use the baseline for the REINFORCE algorithm')
    args = parser.parse_args()
    # Check cuda
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    # Initialize environment: crop, resize to 32x32, keep the green channel,
    # and normalize pixel values.
    env = gym.make('CarRacing-v0')
    env = CropCarRacing(env)
    env = ResizeObservation(env, (32, 32, 3))
    env = Scolorized(env, weights=[0.0, 1.0, 0.0])
    env = NormalizeRGB(env)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    print("Env final goal:", env.spec.reward_threshold)
    # Create the alias for the run
    alias = '%s_%s' % (args.alias, time.time())
    # Use alias for checkpoints
    checkpoint_best_filename = 'policy_weights/' + alias + '_best.torch'
    checkpoint_final_filename = 'policy_weights/' + alias + '_final.torch'
    # BUG FIX: checkpoints are written under 'policy_weights/', so that is the
    # directory that must exist (the code previously created 'weights/').
    if not os.path.exists('policy_weights/'):
        os.makedirs('policy_weights/')
    # Tensorboard writer
    writer = SummaryWriter('logs/' + alias)
    # Create VAE policy
    vape = VAEPolicy()
    optimizer = optim.Adam(vape.parameters(), lr=1e-04)

    # Animation of environment: show the initial observation next to its
    # (untrained) reconstruction.
    obs = env.reset()
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
    rebuild = vape.encode_decode(obs_torch)
    rebuild = NHWC(rebuild.detach().numpy()[0])

    fig1 = plt.figure()
    # Single-channel observations are drawn with a grayscale colormap.
    if len(obs.shape) == 3 and (obs.shape[-1] == 1):
        im = plt.imshow(side_by_side(obs, rebuild), cmap="Greys")
    else:
        im = plt.imshow(side_by_side(obs, rebuild))
    done = False
    HORIZON = 200
    timestep = 0

    # Setting animation update function.  Note: parameter renamed from *args
    # to avoid shadowing the argparse Namespace above.
    def updatefig(*frame_args):
        nonlocal done
        nonlocal obs
        nonlocal HORIZON
        nonlocal timestep
        obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
        if not done and timestep < HORIZON:
            action, action_proba = vape.act(obs_torch)
            action = action[0].detach().numpy()
            obs, reward, done, info = env.step(action)
            env.render(mode='human')
            timestep += 1
        else:
            # Episode over or horizon reached: reset and continue animating.
            done = False
            obs = env.reset()
            timestep = 0
        rebuild = vape.encode_decode(obs_torch)
        rebuild = NHWC(rebuild.detach().numpy()[0])
        im.set_array(side_by_side(obs, rebuild))
        # One online VAE optimization step per rendered frame.
        vape.optimize_vae(obs_torch, optimizer)
        time.sleep(0.01)
        return im,

    # Start animation
    ani = animation.FuncAnimation(fig1, updatefig, interval=50, blit=True)
    plt.show()
    # Close env and writer
    env.close()
    writer.close()
Beispiel #7
0
        env = Scolorized(env, weights=[0.0, 1.0, 0.0])
        env.seed(args.seed)

    # Network creation
    VAE_class = VAEbyArch(args.arch)
    vae = VAE_class(latent_size=args.latent_size).to(device)
    # Restore checkpoint
    assert args.vae, "No checkpoint provided."
    vae.load_state_dict(torch.load(args.vae))
    vae.eval()

    if args.dataset:
        # Single observation display
        mu, log_sigma, z, rebuild = vae(dataset_torch[args.sample:args.sample+1])
        rebuild = rebuild.detach().numpy()[0]
        imshow_bw_or_rgb(side_by_side(dataset[args.sample], NHWC(rebuild)))
        plt.show()
    else:
        # Check if we use a policy
        policy = None
        if args.policy and args.vae_old:
            policy_env = VAEObservation(env, args.vae_old, arch=args.arch)
            policy = Policy(policy_env)
            policy.load_state_dict(torch.load(args.policy))
            policy.eval()
            vae_old = VAE_class(latent_size=args.latent_size).to(device)
            vae_old.load_state_dict(torch.load(args.vae_old))
            vae_old.eval()
        # Animation of environment
        obs = env.reset()
        obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
 def transform_observation(obs):
     """Encode-decode *obs* through the VAE.

     Returns a tuple (rebuild, obs_torch): the NHWC reconstruction as a
     numpy array and the NCHW float tensor fed to the network.
     """
     tensor = torch.from_numpy(NCHW([obs])).float().to(device)
     decoded = vape.encode_decode(tensor)
     return NHWC(decoded.detach().numpy()[0]), tensor
Beispiel #9
0
    env = ResizeObservation(env, (64, 64, 3))
    env = NormalizeRGB(env)
    env = Scolorized(env, weights=[0.0, 1.0, 0.0])
    env.seed(args.seed)

    vape = VAEPolicy()
    # Restore checkpoint
    assert args.controller, "No checkpoint provided."
    vape.load_state_dict(torch.load(args.controller))
    vape.eval()

    # Animation of environment
    obs = env.reset()
    obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
    rebuild, mu, log_sigma, z = vape.encode_decode(obs_torch)
    rebuild = NHWC(rebuild.detach().numpy()[0])

    fig1 = plt.figure()
    if len(obs.shape) == 3 and (obs.shape[-1] == 1):
        im = plt.imshow(side_by_side(obs, rebuild), cmap="Greys")
    else:
        im = plt.imshow(side_by_side(obs, rebuild))
    done = False

    # Setting animation update function
    def updatefig(*args):
        global done
        global obs
        if not done:
            obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
            mu, _ = vape.encode(obs_torch)