Example #1
def training_workflow(config, reporter):
    # Setup policy and policy evaluation actors
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
                                           CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers
        weights = ray.put({"default_policy": policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))
        print("DEBUG* BATCH ************************")
        print(T1)
        print("DEBUG*************************")




        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        reporter(**collect_metrics(remote_evaluators=workers))
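
The snippet assumes that ray, gym, PolicyEvaluator, SampleBatch, collect_metrics and a CustomPolicy class are imported or defined elsewhere in the same file. Below is a minimal sketch of how a function trainable like this is usually launched with the older Tune function API (the variant that passes a reporter); the num_workers and num_iters values are placeholders:

import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    # Run the custom training workflow as a Tune function trainable.
    tune.run(
        training_workflow,
        config={"num_workers": 2, "num_iters": 20},
    )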
Example #2
def training_workflow(config, reporter):
    from gym.spaces import Box
    import numpy as np
    env_maker = get_env_maker(GazeboEnv)
    env, agents = env_maker(config['env_config'], return_agents=True)
    space = Box(low=-np.ones(2), high=np.ones(2))
    replay_buffers = {
        agent_id: ReplayBuffer(config.get('buffer_size', 1000))
        for agent_id in agents
    }
    policy = {
        k: (RandomPolicy, a.observation_space, a.action_space, {})
        for k, a in agents.items()
    }
    worker = RolloutWorker(lambda x: env,
                           policy=policy,
                           batch_steps=32,
                           policy_mapping_fn=lambda x: x,
                           episode_horizon=20)
    for i in range(config['num_iters']):
        T1 = SampleBatch.concat_samples([worker.sample()])
        for agent_id, batch in T1.policy_batches.items():
            for row in batch.rows():
                replay_buffers[agent_id].add(row['obs'],
                                             row['actions'],
                                             row['rewards'],
                                             row['new_obs'],
                                             row['dones'],
                                             weight=None)
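
This snippet assumes get_env_maker, GazeboEnv, RandomPolicy, ReplayBuffer and RolloutWorker are imported elsewhere. The inner loop relies on SampleBatch.rows(), which yields one dict per timestep across all columns. A small self-contained sketch of that behavior (the SampleBatch import path varies across Ray versions):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

batch = SampleBatch({
    "obs": np.array([[0.1], [0.2]]),
    "actions": np.array([0, 1]),
    "rewards": np.array([1.0, 0.5]),
})
# Each row is a dict holding one timestep of every column.
for row in batch.rows():
    print(row["obs"], row["actions"], row["rewards"])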
Example #3
 def testConcat(self):
     b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
     b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
     b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
     b12 = b1.concat(b2)
     self.assertEqual(b12["a"].tolist(), [1, 2, 3, 1])
     self.assertEqual(b12["b"].tolist(), [4, 5, 6, 4])
     b = SampleBatch.concat_samples([b1, b2, b3])
     self.assertEqual(b["a"].tolist(), [1, 2, 3, 1, 1])
     self.assertEqual(b["b"].tolist(), [4, 5, 6, 4, 5])
Example #4
def training_workflow(config, reporter):
    # Setup policy and policy evaluation actors
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        PolicyEvaluator.as_remote().remote(lambda c: gym.make("CartPole-v0"),
                                           CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers
        weights = ray.put({"default_policy": policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas and gather another batch of samples
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_evaluators=workers))
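
The workflow calls get_weights, set_weights, learn_on_batch and update_some_value on a CustomPolicy that is not shown here. A rough sketch of what such a policy could look like, modeled on RLlib's custom-policy pattern (random actions plus a single scalar parameter w); the base-class import path and the full compute_actions signature vary across Ray versions:

import numpy as np
from ray.rllib.policy.policy import Policy


class CustomPolicy(Policy):
    """Toy policy with one scalar parameter and random actions."""

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.w = 1.0  # the value touched by update_some_value()

    def compute_actions(self, obs_batch, state_batches=None, **kwargs):
        # Return random actions; a real policy would evaluate a model here.
        return np.array([self.action_space.sample() for _ in obs_batch]), [], {}

    def learn_on_batch(self, samples):
        # Placeholder for an actual learning update on the batch.
        return {}

    def update_some_value(self, value):
        self.w = value

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]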