import torch


# Note: CartPoleEnv (the environment wrapper used below) is defined elsewhere in this project
def evaluate_agent(agent, episodes, return_trajectories=False, seed=1):
    env = CartPoleEnv()
    env.seed(seed)

    returns, trajectories = [], []
    for _ in range(episodes):
        states, actions, rewards = [], [], []
        state, terminal = env.reset(), False
        while not terminal:
            with torch.no_grad():
                policy, _ = agent(state)
                action = policy.logits.argmax(dim=-1)  # Pick action greedily
                state, reward, terminal = env.step(action)

                if return_trajectories:
                    states.append(state)
                    actions.append(action)
                rewards.append(reward)
        returns.append(sum(rewards))
        if return_trajectories:
            # Collect trajectory data (including terminal signal, which may be needed for offline learning)
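            # Mask layout: 1 for every step except the last, 0 for the final (terminal) step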
            terminals = torch.cat(
                [torch.ones(len(rewards) - 1),
                 torch.zeros(1)])
            trajectories.append(
                dict(states=torch.cat(states),
                     actions=torch.cat(actions),
                     rewards=torch.tensor(rewards, dtype=torch.float32),
                     terminals=terminals))

    return (returns, trajectories) if return_trajectories else returns
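
A minimal usage sketch, assuming `agent` is a trained ActorCritic policy as constructed in Example 2 below:

# Usage sketch: `agent` is assumed to be a trained policy network
eval_returns = evaluate_agent(agent, episodes=10)
print('Mean return:', sum(eval_returns) / len(eval_returns))
# Collect trajectories as well, e.g. to save an expert dataset for imitation learning
eval_returns, eval_trajectories = evaluate_agent(agent, episodes=10, return_trajectories=True)
torch.save(eval_trajectories, 'expert_trajectories.pth')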
Example 2
# (The ArgumentParser construction and earlier arguments such as --seed, --hidden-size,
#  --learning-rate and --imitation are defined earlier in the full script)
parser.add_argument('--imitation-epochs',
                    type=int,
                    default=5,
                    metavar='IE',
                    help='Imitation learning epochs')
parser.add_argument('--imitation-replay-size',
                    type=int,
                    default=1,
                    metavar='IRS',
                    help='Imitation learning trajectory replay size')
args = parser.parse_args()
torch.manual_seed(args.seed)
os.makedirs('results', exist_ok=True)

# Set up environment and models
env = CartPoleEnv()
env.seed(args.seed)
agent = ActorCritic(env.observation_space.shape[0], env.action_space.n,
                    args.hidden_size)
agent_optimiser = optim.RMSprop(agent.parameters(), lr=args.learning_rate)
if args.imitation:
    # Set up expert trajectories dataset
    expert_trajectories = torch.load('expert_trajectories.pth')
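    # Each expert trajectory is a dict of tensors (e.g. states, actions, rewards, terminals, as collected above)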
    expert_trajectories = {
        k: torch.cat([trajectory[k] for trajectory in expert_trajectories],
                     dim=0)
        for k in expert_trajectories[0].keys()
    }  # Flatten expert trajectories
    expert_trajectories = TransitionDataset(expert_trajectories)
    # Set up discriminator
    if args.imitation in ['AIRL', 'GAIL']:
        if args.imitation == 'AIRL':