コード例 #1
0
def test_trajectory_plotting():
    """Smoke-test trajectory plotting on a random 12x12 gridworld.

    Generates a random gridworld (10% walls, 10% reward cells) and converts it
    to numpy ``walls`` / ``reward`` / ``start`` inputs. It then hands those,
    together with a ``MyopicAgent`` (horizon 10) and a fresh ``OptimalAgent``,
    to ``_plot_reward_and_trajectories_helper`` with
    ``filename="trajectory.png"``.

    Takes no parameters on purpose: pytest would interpret any parameter as a
    fixture request.
    """
    from agents import MyopicAgent, OptimalAgent
    from gridworld.gridworld import GridworldMdp

    mdp = GridworldMdp.generate_random(12, 12, pr_wall=0.1, pr_reward=0.1)
    walls, reward, start = mdp.convert_to_numpy_input()
    myopic = MyopicAgent(horizon=10)
    # The same reward grid is passed for both reward arguments; presumably one
    # is the true reward and the other the reward the agents plan with —
    # confirm against the helper's signature.
    _plot_reward_and_trajectories_helper(
        reward, reward, walls, start, myopic, OptimalAgent(), filename="trajectory.png"
    )
コード例 #2
0
        def check_model_equivalent(model, query, weights, mdp, num_iters):
            """Check the model's Q-values against an OptimalAgent, per proxy.

            Opens a ``tf.compat.v1.Session`` and runs ``model.initialize_op``.
            It then fetches ``'q_values'`` from ``model.compute`` (called with
            ``mdp``, ``query`` and ``weight_inits=weights``). After that, for
            each entry of ``model.proxy_reward_space`` it:

            * writes the proxy's values into ``mdp.rewards`` at the ``query``
              indices,
            * rebinds an ``OptimalAgent`` (``gamma=model.gamma``,
              ``num_iters=num_iters``) to the updated ``mdp``,
            * calls ``check_qvals_equivalent`` on the proxy's slice of the
              fetched Q-values.

            Side effect: ``mdp.rewards`` is modified in place and still holds
            the last proxy's values when this returns.
            """
            with tf.compat.v1.Session() as session:
                session.run(model.initialize_op)
                # Exactly one fetch is requested, so exactly one result is
                # expected; unpacking raises if that does not hold.
                [q_values] = model.compute(
                    ['q_values'], session, mdp, query, weight_inits=weights)

            planner = OptimalAgent(gamma=model.gamma, num_iters=num_iters)
            for proxy_index, proxy in enumerate(model.proxy_reward_space):
                # Only the queried reward entries are overwritten; zip() stops
                # at the shorter of `query` and `proxy`.
                for reward_index, proxy_value in zip(query, proxy):
                    mdp.rewards[reward_index] = proxy_value
                planner.set_mdp(mdp)
                check_qvals_equivalent(q_values[proxy_index], planner, mdp)
コード例 #3
0
def run_test(walls, reward):
    """Compare TF value-iteration output against an OptimalAgent's values.

    walls, reward: 2d numpy arrays (numbers), each fed into an
    ``(imsize, imsize)`` float32 placeholder.

    Reference values come from ``castAgentValuesToNumpy`` applied to an
    ``OptimalAgent(num_iters=num_iters)`` bound to ``mdp``. Predicted values
    are the per-cell maximum over the 5 outputs ``test_model`` yields per
    cell. Both grids go to ``compareValues`` and ``visualizeValueDiff``.

    NOTE(review): ``mdp``, ``num_iters``, ``imsize``, ``sess``,
    ``tf_value_iter_model`` and the helper functions are free names resolved
    outside this function. The reference values are NOT derived from
    ``walls``/``reward``; presumably the caller keeps ``mdp`` consistent with
    them — confirm.
    """
    reference_agent = OptimalAgent(num_iters=num_iters)
    reference_agent.set_mdp(mdp)
    true_values = castAgentValuesToNumpy(reference_agent.values)

    # Each call creates fresh placeholders / model ops in the default graph.
    grid_shape = (imsize, imsize)
    walls_input = tf.placeholder(dtype=tf.float32, shape=grid_shape)
    reward_input = tf.placeholder(dtype=tf.float32, shape=grid_shape)
    q_value_op = test_model(walls_input, reward_input, tf_value_iter_model)
    raw_q_values = sess.run(
        q_value_op, feed_dict={walls_input: walls, reward_input: reward})

    # One row per cell, 5 columns (presumably one per action — confirm);
    # the best column is the predicted state value.
    per_state_q = np.reshape(raw_q_values, (imsize * imsize, 5))
    predicted_values = per_state_q.max(axis=1).reshape(grid_shape)

    compareValues(true_values, predicted_values)
    visualizeValueDiff(true_values, predicted_values)