예제 #1
0
 def sample(self, n_episodes, vae_policy, render=False):
     """Collect `n_episodes` rollouts while optimising the VAE on every frame.

     At each step the current observation is turned into a float tensor via
     ``NCHW``. It is handed to ``vae_policy.optimize_vae``, whose return value
     is recorded. It is then passed to ``vae_policy.act`` to choose the action
     sent to ``self.env.step``. An episode ends after ``self.horizon`` steps,
     or as soon as the environment reports ``done``.

     Args:
         n_episodes: number of episodes to roll out.
         vae_policy: object providing ``optimize_vae(obs)`` and ``act(obs)``.
         render: when true, ``self.env.render()`` is called after every step.

     Returns:
         ``(trajectories, losses_and_info)``.

         ``trajectories`` is a dict with 'rewards' and 'mask' (numpy arrays)
         and 'logp' (a double torch tensor), each of shape
         ``(n_episodes, self.horizon)``. ``mask`` is 1 exactly at the steps
         that were taken.

         ``losses_and_info`` is the flat list of every ``optimize_vae``
         result, across all episodes and in call order.
     """
     grid = (n_episodes, self.horizon)
     trajectories = dict(
         rewards=np.zeros(grid),
         mask=np.zeros(grid),
         logp=torch.zeros(n_episodes, self.horizon, dtype=torch.double),
     )
     vae_results = []
     for episode in range(n_episodes):
         observation = self.env.reset()
         for step in range(self.horizon):
             # Convert for PyTorch and let the VAE learn from this frame.
             frame = torch.from_numpy(NCHW([observation])).float()
             vae_results.append(vae_policy.optimize_vae(frame))
             # Pick and apply the action.
             action, log_prob = vae_policy.act(frame)
             observation, reward, done, _ = self.env.step(action)
             # Book-keeping for this (episode, step) cell.
             cell = (episode, step)
             trajectories['rewards'][cell] = reward
             trajectories['logp'][cell] = log_prob
             trajectories['mask'][cell] = 1
             if render:
                 self.env.render()
             if done:
                 break
     return trajectories, vae_results
예제 #2
0
 def _observation(self, obs):
     """Encode a single observation into the VAE latent mean.

     Args:
         obs: one raw observation. It is wrapped in a one-element list,
             converted with the project's ``NCHW`` helper and cast to float
             before encoding.

     Returns:
         A numpy array holding ``mu[0]``, the first (and only) row of the
         latent mean returned by ``self.vae.encode``.
     """
     # First convert to torch notation and type (a batch of one).
     obs_torch = torch.from_numpy(NCHW([obs])).float()
     if self.device is not None:
         obs_torch = obs_torch.to(self.device)
     # Get the compressed representation; the second output is discarded.
     mu, _ = self.vae.encode(obs_torch)
     # Tensor.numpy() only works on CPU tensors. .cpu() is a no-op for a
     # tensor already on the CPU, so this works for any self.device.
     return mu.detach().cpu().numpy()[0]
예제 #3
0
 def updatefig(*args):
     """Advance the environment one frame and refresh the side-by-side image.

     Animation callback that works on the global ``done``/``obs`` state.

     * While the episode is running, the current observation is fed to
       ``vape.act`` and the chosen action is applied to ``env``.
     * Once ``done`` is set, the environment is reset instead.
     * Either way, the resulting observation is reconstructed with
       ``vape.encode_decode`` and shown next to the original in ``im``.

     Returns:
         A one-element tuple containing the updated artist ``im``.
     """
     global done
     global obs
     if not done:
         obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
         action, _ = vape.act(obs_torch)
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
     else:
         done = False
         obs = env.reset()
     # Reconstruct the observation that is actually being displayed.
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     rebuild, mu, log_sigma, z = vape.encode_decode(obs_torch)
     # Move to CPU before .numpy(); required when `device` is a GPU.
     rebuild = NHWC(rebuild.detach().cpu().numpy()[0])
     im.set_array(side_by_side(obs, rebuild))
     time.sleep(0.01)
     return im,
예제 #4
0
 def updatefig(*args):
     """Advance the environment one frame and refresh the side-by-side image.

     Animation callback that works on the global ``done``/``obs`` state.

     * While the episode is running, an action is chosen and applied to
       ``env``. If ``policy`` is set, it acts on the VAE latent mean of the
       current observation. Otherwise a random action is sampled.
     * Once ``done`` is set, the environment is reset instead.
     * The resulting observation is then passed through ``vae`` and shown
       next to its reconstruction in ``im``.

     Returns:
         A one-element tuple containing the updated artist ``im``.
     """
     global done
     global obs
     if not done:
         if policy:
             obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
             mu, _ = vae.encode(obs_torch)
             # The policy consumes a numpy latent. .cpu() is needed before
             # .numpy() when `device` is a GPU.
             action, action_proba = policy.act(mu.detach().cpu().numpy())
             action = action[0]
         else:
             action = env.action_space.sample()
             # Keep the sampled first component and fix the other two.
             # NOTE(review): presumably steering random, constant gas, no
             # brake (CarRacing action layout) — confirm.
             action = [action[0], 0.3, 0.0]
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
     else:
         done = False
         obs = env.reset()
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     mu, log_sigma, z, rebuild = vae(obs_torch)
     # Move to CPU before .numpy(); required when `device` is a GPU.
     rebuild = NHWC(rebuild.detach().cpu().numpy()[0])
     im.set_array(side_by_side(obs, rebuild))
     time.sleep(0.01)
     return im,
예제 #5
0
 def updatefig(*args):
     nonlocal done
     nonlocal obs
     nonlocal HORIZON
     nonlocal timestep
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     if not done and timestep < HORIZON:
         action, action_proba = vape.act(obs_torch)
         action = action[0].detach().numpy()
         obs, reward, done, info = env.step(action)
         env.render(mode='human')
         timestep += 1
     else:
         done = False
         obs = env.reset()
         timestep = 0
     rebuild = vape.encode_decode(obs_torch)
     rebuild = NHWC(rebuild.detach().numpy()[0])
     im.set_array(side_by_side(obs, rebuild))
     vape.optimize_vae(obs_torch, optimizer)
     time.sleep(0.01)
     return im,
예제 #6
0
def main():
    """Visualise a VAEPolicy on CarRacing while training its VAE online.

    Builds the wrapped ``CarRacing-v0`` environment from command-line
    options, then runs a matplotlib animation. On every frame:

    * the policy acts on the current observation (episodes are cut off after
      ``HORIZON`` steps or when the environment reports ``done``);
    * the new observation is shown next to its VAE reconstruction;
    * the VAE takes one optimisation step on that observation.

    The environment and the TensorBoard writer are closed when the
    animation window is closed, or if anything inside it raises.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description='REINFORCE using PyTorch')
    # Logging
    parser.add_argument('--alias',
                        type=str,
                        default='base',
                        help="""Alias of the model.""")
    parser.add_argument('--render_interval',
                        type=int,
                        default=100,
                        help='interval between rendered epochs (default: 100)')
    # Learning parameters
    parser.add_argument('--gamma',
                        type=float,
                        default=0.99,
                        help='discount factor (default: 0.99)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        help='learning rate (default: 0.01)')

    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='Disables CUDA training')
    parser.add_argument('--eb',
                        type=int,
                        default=1,
                        help='episode batch (default: 1)')
    parser.add_argument('--episodes',
                        type=int,
                        default=10000,
                        help='simulated episodes (default: 10000)')
    parser.add_argument('--policy',
                        type=str,
                        default=None,
                        help="""Policy checkpoint to restore.""")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help='random seed (default: 42)')
    parser.add_argument('--horizon',
                        type=int,
                        default=1000,
                        help='horizon (default: 1000)')
    parser.add_argument('--baseline',
                        action='store_true',
                        help='use the baseline for the REINFORCE algorithm')
    args = parser.parse_args()
    # NOTE(review): several options are parsed but not used in this function:
    #   --render_interval, --gamma, --eb, --episodes, --policy, --baseline
    #   --lr: the optimizer below hard-codes lr=1e-04
    #   --horizon: HORIZON below is hard-coded to 200
    # Check cuda
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    # Initialize environment
    env = gym.make('CarRacing-v0')
    env = CropCarRacing(env)
    env = ResizeObservation(env, (32, 32, 3))
    env = Scolorized(env, weights=[0.0, 1.0, 0.0])
    env = NormalizeRGB(env)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    print("Env final goal:", env.spec.reward_threshold)
    # Create the alias for the run
    alias = '%s_%s' % (args.alias, time.time())
    # Use alias for checkpoints, and make sure the directory they live in exists.
    checkpoint_dir = 'policy_weights/'
    checkpoint_best_filename = checkpoint_dir + alias + '_best.torch'
    checkpoint_final_filename = checkpoint_dir + alias + '_final.torch'
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Tensorboard writer
    writer = SummaryWriter('logs/' + alias)
    # Create the VAE policy on the same device as the observation tensors,
    # before building the optimizer over its parameters.
    vape = VAEPolicy().to(device)
    optimizer = optim.Adam(vape.parameters(), lr=1e-04)

    def to_torch(frame):
        """Wrap one observation as a batch-of-one float tensor on `device`."""
        return torch.from_numpy(NCHW([frame])).float().to(device)

    def reconstruct(frame_torch):
        """Return the VAE reconstruction of `frame_torch`, converted with NHWC."""
        # .cpu() is required before .numpy() when `device` is a GPU.
        return NHWC(vape.encode_decode(frame_torch).detach().cpu().numpy()[0])

    try:
        # Animation of environment
        obs = env.reset()
        rebuild = reconstruct(to_torch(obs))

        fig1 = plt.figure()
        if len(obs.shape) == 3 and (obs.shape[-1] == 1):
            im = plt.imshow(side_by_side(obs, rebuild), cmap="Greys")
        else:
            im = plt.imshow(side_by_side(obs, rebuild))
        done = False
        HORIZON = 200
        timestep = 0

        # Setting animation update function
        def updatefig(*args):
            """Step the env once, show the new frame and its reconstruction."""
            nonlocal done
            nonlocal obs
            nonlocal timestep
            if not done and timestep < HORIZON:
                action, action_proba = vape.act(to_torch(obs))
                action = action[0].detach().cpu().numpy()
                obs, reward, done, info = env.step(action)
                env.render(mode='human')
                timestep += 1
            else:
                done = False
                obs = env.reset()
                timestep = 0
            # Encode the frame *after* stepping/resetting. This keeps the image
            # shown and the frame the VAE is trained on in sync.
            obs_torch = to_torch(obs)
            im.set_array(side_by_side(obs, reconstruct(obs_torch)))
            vape.optimize_vae(obs_torch, optimizer)
            time.sleep(0.01)
            return im,

        # Start animation. Keep the reference in `ani`: an Animation object
        # that gets garbage-collected stops updating.
        ani = animation.FuncAnimation(fig1, updatefig, interval=50, blit=True)
        plt.show()
    finally:
        # Close env and writer
        env.close()
        writer.close()
예제 #7
0
    parser.add_argument('--dataset', type=str, default='', help="""Dataset file to load.""")
    parser.add_argument('--arch', type=str, default='base_car_racing', help="""Model architecture.""")
    parser.add_argument('--seed', type=int, default=42, help="""Seed used in the environment initialization.""")
    parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training')
    args, unparsed = parser.parse_known_args()
    # Validate the checkpoint argument before doing any expensive work.
    # `assert` would be stripped under `python -O`; parser.error reports the
    # problem with the usage string and exits.
    # NOTE(review): assumes `--vae` and `--latent_size` are registered on
    # `parser` before this point (not visible here) — confirm.
    if not args.vae:
        parser.error("No checkpoint provided.")
    # Check cuda
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    # Loading the dataset
    if args.dataset:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # dataset files from a trusted source.
        with open(args.dataset, 'rb') as dataset_file:
            dataset = np.array(pickle.load(dataset_file))
        N_SAMPLES, W, H, CHANNELS = dataset.shape
        print("Dataset size:", N_SAMPLES)
        print("Channels:", CHANNELS)
        print("Image dim: (%d,%d)" % (W,H))
        dataset_torch = torch.from_numpy(NCHW(dataset)).float().to(device)
    else:
        print("Using gym environment directly.")
        env = gym.make('CarRacing-v0')
        env = CropCarRacing(env)
        env = ResizeObservation(env, (32, 32, 3))
        # NOTE(review): confirm NormalizeRGB is meant to run before
        # Scolorized here; other scripts apply them in the opposite order.
        env = NormalizeRGB(env)
        env = Scolorized(env, weights=[0.0, 1.0, 0.0])
        env.seed(args.seed)

    # Network creation
    VAE_class = VAEbyArch(args.arch)
    vae = VAE_class(latent_size=args.latent_size).to(device)
    # Restore checkpoint. map_location remaps the stored tensors onto `device`,
    # so a checkpoint saved on a GPU also loads on a CPU-only machine.
    vae.load_state_dict(torch.load(args.vae, map_location=device))
예제 #8
0
 def transform_observation(obs):
     """Reconstruct one observation with the VAE.

     Args:
         obs: one raw observation. It is wrapped in a one-element list,
             converted with the project's ``NCHW`` helper, cast to float and
             moved to ``device``.

     Returns:
         ``(rebuild, obs_torch)``:

         * ``rebuild`` is the first item of ``vape.encode_decode``'s output,
           converted back with ``NHWC`` as a numpy array.
         * ``obs_torch`` is the tensor that was fed to the VAE.
     """
     obs_torch = torch.from_numpy(NCHW([obs])).float().to(device)
     rebuild = vape.encode_decode(obs_torch)
     # .cpu() is required before .numpy() when `device` is a GPU, and is a
     # no-op on the CPU.
     rebuild = NHWC(rebuild.detach().cpu().numpy()[0])
     return rebuild, obs_torch