Ejemplo n.º 1
0
 def setup(self, config):
     """Run the base-class setup, then attach a local replay buffer.

     The parent `setup` is invoked first (it creates the rollout workers);
     afterwards a single-shard `MultiAgentReplayBuffer` with fixed sizing
     parameters is stored on `self.local_replay_buffer`.
     """
     super().setup(config)
     # Fixed buffer sizing for this trainer.
     replay_kwargs = {
         "num_shards": 1,
         "learning_starts": 1000,
         "capacity": 50000,
         "replay_batch_size": 64,
     }
     self.local_replay_buffer = MultiAgentReplayBuffer(**replay_kwargs)
Ejemplo n.º 2
0
 def setup(self, config: PartialTrainerConfigDict):
     """Run the base-class setup and, for the new API, build the buffer.

     Only when `self.config["_disable_execution_plan_api"]` is exactly
     `True` (i.e. the `training_iteration` API is in use) is a local
     `MultiAgentReplayBuffer` created here and stored on
     `self.local_replay_buffer`. The deprecated `execution_plan` path is
     not handled by this method.
     """
     super().setup(config)
     cfg = self.config
     # Guard clause: nothing to do unless the execution-plan API is off.
     if cfg["_disable_execution_plan_api"] is not True:
         return
     self.local_replay_buffer = MultiAgentReplayBuffer(
         learning_starts=cfg["learning_starts"],
         capacity=cfg["replay_buffer_size"],
         replay_batch_size=cfg["train_batch_size"],
         replay_sequence_length=1,
     )
Ejemplo n.º 3
0
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        """Execution plan of the MARWIL/BC algorithm.

        Every rollout batch is stored into a local replay buffer. In
        round-robin alternation with storing, batches are replayed from that
        buffer, concatenated to at least `train_batch_size` steps and used
        for one training step. Only the training op's results are reported.

        Args:
            workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
                of the Trainer.
            config (TrainerConfigDict): The trainer's configuration dict.
            **kwargs: Must be empty; no additional parameters are accepted.

        Returns:
            LocalIterator[dict]: A local iterator over training metrics.

        Raises:
            AssertionError: If extra keyword arguments are passed (when
                assertions are enabled).
        """
        # Report which unexpected keys were passed to ease debugging.
        assert len(kwargs) == 0, (
            "MARWIL execution_plan does NOT take any additional parameters "
            "(got: {})".format(sorted(kwargs)))

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = MultiAgentReplayBuffer(
            learning_starts=config["learning_starts"],
            capacity=config["replay_buffer_size"],
            replay_batch_size=config["train_batch_size"],
            replay_sequence_length=1,
        )

        # Store every collected rollout batch into the local replay buffer.
        store_op = rollouts \
            .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

        # Replay from the buffer, concatenate to at least `train_batch_size`
        # steps, then run one training step on the result.
        replay_op = Replay(local_buffer=replay_buffer) \
            .combine(
                ConcatBatches(
                    min_batch_size=config["train_batch_size"],
                    count_steps_by=config["multiagent"]["count_steps_by"],
                )) \
            .for_each(TrainOneStep(workers))

        # Alternate store/replay; only the replay/train op (index 1) is
        # forwarded to metrics reporting.
        train_op = Concurrently([store_op, replay_op],
                                mode="round_robin",
                                output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)
Ejemplo n.º 4
0
def custom_training_workflow(workers: WorkerSet, config: dict):
    """Train a DQN policy and a PPO policy concurrently on shared rollouts.

    Rollouts are collected once (bulk-sync) and duplicated into two
    sub-flows:

    * DQN: "dqn_policy" samples are stored into a local replay buffer;
      replayed batches train "dqn_policy" and drive `UpdateTargetNetwork`
      (with `target_update_freq=500`).
    * PPO: "ppo_policy" samples are concatenated into batches of at least
      200 env steps, their "advantages" are standardized, and they train
      "ppo_policy" (`num_sgd_iter=10`, `sgd_minibatch_size=128`).

    Both sub-flows run asynchronously; only the DQN sub-flow's outputs are
    forwarded to metrics reporting.

    Args:
        workers: The WorkerSet used for sampling and training.
        config: The trainer's config dict, forwarded to
            `StandardMetricsReporting`.

    Returns:
        The iterator produced by `StandardMetricsReporting` over the
        combined training flow.
    """
    local_replay_buffer = MultiAgentReplayBuffer(num_shards=1,
                                                 learning_starts=1000,
                                                 capacity=50000,
                                                 replay_batch_size=64)

    def _make_metrics_fn(policy_tag):
        """Return a pass-through step that logs and counts trained steps.

        The returned callable prints which policies the batch holds with
        its env- and agent-step counts, adds the batch's agent steps to the
        shared counter ``"agent_steps_trained_" + policy_tag`` and returns
        the batch unchanged.
        """

        def add_metrics(batch):
            print(policy_tag + " policy learning on samples from",
                  batch.policy_batches.keys(), "env steps", batch.env_steps(),
                  "agent steps", batch.agent_steps())
            metrics = _get_shared_metrics()
            # Count agent steps (not env steps) so the value matches both
            # the printed label and the counter's name; the two can differ
            # in multi-agent envs.
            metrics.counters["agent_steps_trained_" + policy_tag] += \
                batch.agent_steps()
            return batch

        return add_metrics

    add_ppo_metrics = _make_metrics_fn("PPO")
    add_dqn_metrics = _make_metrics_fn("DQN")

    # Generate common experiences.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # DQN sub-flow: store -> replay -> train -> target-net update.
    dqn_store_op = r1.for_each(SelectExperiences(["dqn_policy"])) \
        .for_each(
            StoreToReplayBuffer(local_buffer=local_replay_buffer))
    dqn_replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(add_dqn_metrics) \
        .for_each(TrainOneStep(workers, policies=["dqn_policy"])) \
        .for_each(UpdateTargetNetwork(
            workers, target_update_freq=500, policies=["dqn_policy"]))
    dqn_train_op = Concurrently([dqn_store_op, dqn_replay_op],
                                mode="round_robin",
                                output_indexes=[1])

    # PPO sub-flow: on-policy, trains directly on concatenated rollouts.
    ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \
        .combine(ConcatBatches(
            min_batch_size=200, count_steps_by="env_steps")) \
        .for_each(add_ppo_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(
            workers,
            policies=["ppo_policy"],
            num_sgd_iter=10,
            sgd_minibatch_size=128))

    # Combined training flow; only the DQN sub-flow (index 1) is reported.
    train_op = Concurrently([ppo_train_op, dqn_train_op],
                            mode="async",
                            output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Ejemplo n.º 5
0
def test_store_to_replay_local(ray_start_regular_shared):
    """Rollouts stored via `StoreToReplayBuffer` become replayable locally.

    The buffer (learning_starts=200, replay_batch_size=100) must yield
    nothing when empty and after the first stored batch, and batches of
    100 after the second; the `Replay` op must draw from the same buffer.
    """
    replay_buffer = MultiAgentReplayBuffer(
        num_shards=1,
        learning_starts=200,
        capacity=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    # Nothing stored yet.
    assert replay_buffer.replay() is None

    rollouts = ParallelRollouts(make_workers(0), mode="bulk_sync")
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=replay_buffer))

    # After one stored batch learning has not started yet.
    next(store_op)
    assert replay_buffer.replay() is None
    # After a second batch, replay yields full-size batches.
    next(store_op)
    assert replay_buffer.replay().count == 100

    # The Replay op samples from the same local buffer.
    assert next(Replay(local_buffer=replay_buffer)).count == 100
Ejemplo n.º 6
0
def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                   **kwargs) -> LocalIterator[dict]:
    """Execution plan of the MARWIL/BC algorithm. Defines the distributed
    dataflow.

    Every rollout batch is stored into a local replay buffer. In
    round-robin alternation with storing, batches are replayed from that
    buffer, concatenated to at least `train_batch_size` steps and used for
    one training step. Only the training op's results are reported.

    Args:
        workers (WorkerSet): The WorkerSet for training the Polic(y/ies)
            of the Trainer.
        config (TrainerConfigDict): The trainer's configuration dict.
        **kwargs: Must be empty; no additional parameters are accepted.

    Returns:
        LocalIterator[dict]: A local iterator over training metrics.

    Raises:
        AssertionError: If extra keyword arguments are passed (when
            assertions are enabled).
    """
    # Report which unexpected keys were passed to ease debugging.
    assert len(kwargs) == 0, (
        "MARWIL execution_plan does NOT take any additional parameters "
        "(got: {})".format(sorted(kwargs)))

    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = MultiAgentReplayBuffer(
        learning_starts=config["learning_starts"],
        capacity=config["replay_buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_sequence_length=1,
    )

    # Store every collected rollout batch into the local replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Replay from the buffer, concatenate to at least `train_batch_size`
    # steps, then run one training step on the result.
    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    # Alternate store/replay; only the replay/train op (index 1) is
    # forwarded to metrics reporting.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
Ejemplo n.º 7
0
    def test_ddpg_loss_function(self):
        """Tests DDPG loss function results across all frameworks.

        A DDPGTrainer is built first for "tf", then for "torch". The tf
        weights are copied into the torch model and target model via
        `map_`. Losses, gradients, and the weights after six
        `trainer.train()` calls on identical fake batches are then compared
        between the frameworks and against `self._ddpg_loss_helper`.
        """
        config = ddpg.DEFAULT_CONFIG.copy()
        # Run locally.
        config["seed"] = 42
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = True
        config["use_huber"] = True
        config["huber_threshold"] = 1.0
        config["gamma"] = 0.99
        # Make this small (seems to introduce errors).
        config["l2_reg"] = 1e-10
        config["prioritized_replay"] = False
        # Use very simple nets.
        config["actor_hiddens"] = [10]
        config["critic_hiddens"] = [10]
        # Make sure, timing differences do not affect trainer.train().
        config["min_time_s_per_reporting"] = 0
        config["timesteps_per_iteration"] = 100

        # Maps tf variable names (keys of `policy.get_weights()`) to torch
        # state-dict keys. The tf target-net entries map to the same torch
        # keys as the main net, because torch keeps the target net in a
        # separate `policy.target_model` with identical structure.
        map_ = {
            # Normal net.
            "default_policy/actor_hidden_0/kernel":
            "policy_model.action_0."
            "_model.0.weight",
            "default_policy/actor_hidden_0/bias":
            "policy_model.action_0."
            "_model.0.bias",
            "default_policy/actor_out/kernel":
            "policy_model.action_out."
            "_model.0.weight",
            "default_policy/actor_out/bias":
            "policy_model.action_out._model.0.bias",
            "default_policy/sequential/q_hidden_0/kernel":
            "q_model.q_hidden_0"
            "._model.0.weight",
            "default_policy/sequential/q_hidden_0/bias":
            "q_model.q_hidden_0."
            "_model.0.bias",
            "default_policy/sequential/q_out/kernel":
            "q_model.q_out._model."
            "0.weight",
            "default_policy/sequential/q_out/bias":
            "q_model.q_out._model.0.bias",
            # -- twin.
            "default_policy/sequential_1/twin_q_hidden_0/kernel":
            "twin_"
            "q_model.twin_q_hidden_0._model.0.weight",
            "default_policy/sequential_1/twin_q_hidden_0/bias":
            "twin_"
            "q_model.twin_q_hidden_0._model.0.bias",
            "default_policy/sequential_1/twin_q_out/kernel":
            "twin_"
            "q_model.twin_q_out._model.0.weight",
            "default_policy/sequential_1/twin_q_out/bias":
            "twin_"
            "q_model.twin_q_out._model.0.bias",
            # Target net.
            "default_policy/actor_hidden_0_1/kernel":
            "policy_model.action_0."
            "_model.0.weight",
            "default_policy/actor_hidden_0_1/bias":
            "policy_model.action_0."
            "_model.0.bias",
            "default_policy/actor_out_1/kernel":
            "policy_model.action_out."
            "_model.0.weight",
            "default_policy/actor_out_1/bias":
            "policy_model.action_out._model"
            ".0.bias",
            "default_policy/sequential_2/q_hidden_0/kernel":
            "q_model."
            "q_hidden_0._model.0.weight",
            "default_policy/sequential_2/q_hidden_0/bias":
            "q_model."
            "q_hidden_0._model.0.bias",
            "default_policy/sequential_2/q_out/kernel":
            "q_model."
            "q_out._model.0.weight",
            "default_policy/sequential_2/q_out/bias":
            "q_model.q_out._model.0.bias",
            # -- twin.
            "default_policy/sequential_3/twin_q_hidden_0/kernel":
            "twin_"
            "q_model.twin_q_hidden_0._model.0.weight",
            "default_policy/sequential_3/twin_q_hidden_0/bias":
            "twin_"
            "q_model.twin_q_hidden_0._model.0.bias",
            "default_policy/sequential_3/twin_q_out/kernel":
            "twin_"
            "q_model.twin_q_out._model.0.weight",
            "default_policy/sequential_3/twin_q_out/bias":
            "twin_"
            "q_model.twin_q_out._model.0.bias",
        }

        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 1))

        # Fixed input batch of `batch_size` samples, shared by all
        # frameworks for the loss/gradient comparison.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_t = None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
        for fw, sess in framework_iterator(config,
                                           frameworks=("tf", "torch"),
                                           session=True):
            # Generate Trainer and get its default Policy object.
            trainer = ddpg.DDPGTrainer(config=config, env=env)
            policy = trainer.get_policy()
            # Only the tf (graph-mode) policy provides a session.
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                assert fw == "tf"  # Start with the tf vars-dict.
                weights_dict = policy.get_weights()
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "torch":
                # Actually convert to torch tensors (accessing every key
                # materializes the lazy tensor dict).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}

            # Only run the expectation once, should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_t = self._ddpg_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    fw,
                    gamma=config["gamma"],
                    huber_threshold=config["huber_threshold"],
                    l2_reg=config["l2_reg"],
                    sess=sess,
                )

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic loss, a=actor loss, t=td-error.
            if fw == "tf":
                c, a, t, tf_c_grads, tf_a_grads = p_sess.run(
                    [
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.td_error,
                        policy._critic_optimizer.compute_gradients(
                            policy.critic_loss, policy.model.q_variables()),
                        policy._actor_optimizer.compute_gradients(
                            policy.actor_loss,
                            policy.model.policy_variables()),
                    ],
                    feed_dict=policy._get_loss_inputs_dict(input_,
                                                           shuffle=False),
                )
                # Check pure loss values.
                check(c, expect_c)
                check(a, expect_a)
                check(t, expect_t)

                # Keep only the gradients from the (grad, var) pairs.
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, t = (
                    policy.get_tower_stats("critic_loss")[0],
                    policy.get_tower_stats("actor_loss")[0],
                    policy.get_tower_stats("td_error")[0],
                )
                # Check pure loss values.
                check(c, expect_c)
                check(a, expect_a)
                check(t, expect_t)

                # Test actor gradients.
                policy._actor_optimizer.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None
                           for v in policy.model.policy_variables())
                a.backward()
                # `actor_loss` depends on Q-net vars
                # (but not twin-Q-net vars!).
                # NOTE(review): assumes the first 4 entries of
                # q_variables() are the Q-net and the rest the twin-Q-net.
                assert not any(v.grad is None
                               for v in policy.model.q_variables()[:4])
                assert all(v.grad is None
                           for v in policy.model.q_variables()[4:])
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                # Compare with tf ones (kernels are stored transposed in
                # torch, hence the transpose on shape mismatch).
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)

                # Test critic gradients.
                policy._critic_optimizer.zero_grad()
                assert all(v.grad is None or torch.mean(v.grad) == 0.0
                           for v in policy.model.q_variables())
                assert all(v.grad is None or torch.min(v.grad) == 0.0
                           for v in policy.model.q_variables())
                c.backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables())
                assert not all(
                    torch.min(v.grad) == 0 for v in policy.model.q_variables())
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                for tf_g, torch_g in zip(tf_c_grads, torch_c_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                for tf_g, torch_g in zip(tf_a_grads, torch_a_grads):
                    if tf_g.shape != torch_g.shape:
                        check(tf_g, np.transpose(torch_g.cpu()))
                    else:
                        check(tf_g, torch_g)

            # Compare with the previous framework's losses (if any), then
            # store this framework's losses for the next framework.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(t, prev_fw_loss[2])

            prev_fw_loss = (c, a, t)

            # Update weights from our batch (6 times).
            for update_iteration in range(6):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights[
                                "default_policy/actor_hidden_0/kernel"],
                            tf_updated_weights[-1]
                            ["default_policy/actor_hidden_0/kernel"],
                            false=True,
                        )
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model and target weights.
                    for tf_key in tf_weights.keys():
                        tf_var = tf_weights[tf_key]
                        # Target model (tf target-net variable names).
                        if re.search(
                                "actor_out_1|actor_hidden_0_1|sequential_[23]",
                                tf_key):
                            torch_var = policy.target_model.state_dict()[
                                map_[tf_key]]
                        # Model.
                        else:
                            torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(tf_var,
                                  np.transpose(torch_var.cpu()),
                                  atol=0.1)
                        else:
                            check(tf_var, torch_var, atol=0.1)

            trainer.stop()
Ejemplo n.º 8
0
    def test_sac_loss_function(self):
        """Tests SAC loss function results across all frameworks.

        A SACTrainer (single Q-net, deterministic loss, simplex actions) is
        built first for "tf", then for "torch". The tf weights are copied
        into the torch model and target model via `map_`. Losses,
        gradients, and the weights after five `trainer.train()` calls on
        identical fake batches are then compared between the frameworks
        and against `self._sac_loss_helper`.
        """
        import copy

        # Deep copy: the nested "Q_model"/"policy_model" dicts are modified
        # below, and a shallow `.copy()` would mutate the module-level
        # `sac.DEFAULT_CONFIG` (leaking into other tests).
        config = copy.deepcopy(sac.DEFAULT_CONFIG)
        # Run locally.
        config["num_workers"] = 0
        config["learning_starts"] = 0
        config["twin_q"] = False
        config["gamma"] = 0.99
        # Switch on deterministic loss so we can compare the loss values.
        config["_deterministic_loss"] = True
        # Use very simple nets.
        config["Q_model"]["fcnet_hiddens"] = [10]
        config["policy_model"]["fcnet_hiddens"] = [10]
        # Make sure, timing differences do not affect trainer.train().
        config["min_time_s_per_reporting"] = 0
        # Test SAC with Simplex action space.
        config["env_config"] = {"simplex_actions": True}

        # Maps tf variable names (keys of `policy.get_weights()`) to torch
        # state-dict keys. Target-net entries (`_2`/`_3` suffixes) map to
        # the same torch keys as the main nets, because torch keeps them in
        # a separate `policy.target_model` with identical structure.
        map_ = {
            # Action net.
            "default_policy/fc_1/kernel": "action_model._hidden_layers.0."
            "_model.0.weight",
            "default_policy/fc_1/bias": "action_model._hidden_layers.0."
            "_model.0.bias",
            "default_policy/fc_out/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out/bias": "action_model._logits._model.0.bias",
            "default_policy/value_out/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out/bias": "action_model."
            "_value_branch._model.0.bias",
            # Q-net.
            "default_policy/fc_1_1/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_1/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_1/kernel": "q_net._logits._model.0.weight",
            "default_policy/fc_out_1/bias": "q_net._logits._model.0.bias",
            "default_policy/value_out_1/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_1/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha": "log_alpha",
            # Target action-net.
            "default_policy/fc_1_2/kernel": "action_model."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_2/bias": "action_model."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_2/kernel": "action_model."
            "_logits._model.0.weight",
            "default_policy/fc_out_2/bias": "action_model."
            "_logits._model.0.bias",
            "default_policy/value_out_2/kernel": "action_model."
            "_value_branch._model.0.weight",
            "default_policy/value_out_2/bias": "action_model."
            "_value_branch._model.0.bias",
            # Target Q-net
            "default_policy/fc_1_3/kernel": "q_net."
            "_hidden_layers.0._model.0.weight",
            "default_policy/fc_1_3/bias": "q_net."
            "_hidden_layers.0._model.0.bias",
            "default_policy/fc_out_3/kernel": "q_net."
            "_logits._model.0.weight",
            "default_policy/fc_out_3/bias": "q_net."
            "_logits._model.0.bias",
            "default_policy/value_out_3/kernel": "q_net."
            "_value_branch._model.0.weight",
            "default_policy/value_out_3/bias": "q_net."
            "_value_branch._model.0.bias",
            "default_policy/log_alpha_1": "log_alpha",
        }

        env = SimpleEnv
        batch_size = 100
        obs_size = (batch_size, 1)
        actions = np.random.random(size=(batch_size, 2))

        # Fixed input batch of `batch_size` samples, shared by all
        # frameworks for the loss/gradient comparison.
        input_ = self._get_batch_helper(obs_size, actions, batch_size)

        # Simply compare loss values AND grads of all frameworks with each
        # other.
        prev_fw_loss = weights_dict = None
        expect_c, expect_a, expect_e, expect_t = None, None, None, None
        # History of tf-updated NN-weights over n training steps.
        tf_updated_weights = []
        # History of input batches used.
        tf_inputs = []
        for fw, sess in framework_iterator(config,
                                           frameworks=("tf", "torch"),
                                           session=True):
            # Generate Trainer and get its default Policy object.
            trainer = sac.SACTrainer(config=config, env=env)
            policy = trainer.get_policy()
            # Only the tf (graph-mode) policy provides a session.
            p_sess = None
            if sess:
                p_sess = policy.get_session()

            # Set all weights (of all nets) to fixed values.
            if weights_dict is None:
                # Start with the tf vars-dict.
                assert fw in ["tf2", "tf", "tfe"]
                weights_dict = policy.get_weights()
                # NOTE(review): only ("tf", "torch") are requested from
                # framework_iterator above, so "tfe" branches look
                # unreachable in this test.
                if fw == "tfe":
                    log_alpha = weights_dict[10]
                    weights_dict = self._translate_tfe_weights(
                        weights_dict, map_)
            else:
                assert fw == "torch"  # Then transfer that to torch Model.
                model_dict = self._translate_weights_to_torch(
                    weights_dict, map_)
                # Have to add this here (not a parameter in tf, but must be
                # one in torch, so it gets properly copied to the GPU(s)).
                model_dict["target_entropy"] = policy.model.target_entropy
                policy.model.load_state_dict(model_dict)
                policy.target_model.load_state_dict(model_dict)

            if fw == "tf":
                log_alpha = weights_dict["default_policy/log_alpha"]
            elif fw == "torch":
                # Actually convert to torch tensors (by accessing everything).
                input_ = policy._lazy_tensor_dict(input_)
                input_ = {k: input_[k] for k in input_.keys()}
                log_alpha = policy.model.log_alpha.detach().cpu().numpy()[0]

            # Only run the expectation once, should be the same anyway
            # for all frameworks.
            if expect_c is None:
                expect_c, expect_a, expect_e, expect_t = self._sac_loss_helper(
                    input_,
                    weights_dict,
                    sorted(weights_dict.keys()),
                    log_alpha,
                    fw,
                    gamma=config["gamma"],
                    sess=sess,
                )

            # Get actual outs and compare to expectation AND previous
            # framework. c=critic, a=actor, e=entropy (alpha) loss,
            # t=td-error.
            if fw == "tf":
                c, a, e, t, tf_c_grads, tf_a_grads, tf_e_grads = p_sess.run(
                    [
                        policy.critic_loss,
                        policy.actor_loss,
                        policy.alpha_loss,
                        policy.td_error,
                        policy.optimizer().compute_gradients(
                            policy.critic_loss[0],
                            [
                                v for v in policy.model.q_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.actor_loss,
                            [
                                v for v in policy.model.policy_variables()
                                if "value_" not in v.name
                            ],
                        ),
                        policy.optimizer().compute_gradients(
                            policy.alpha_loss, policy.model.log_alpha),
                    ],
                    feed_dict=policy._get_loss_inputs_dict(input_,
                                                           shuffle=False),
                )
                # Keep only the gradients from the (grad, var) pairs.
                tf_c_grads = [g for g, v in tf_c_grads]
                tf_a_grads = [g for g, v in tf_a_grads]
                tf_e_grads = [g for g, v in tf_e_grads]

            elif fw == "tfe":
                with tf.GradientTape() as tape:
                    tf_loss(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.critic_loss,
                    policy.actor_loss,
                    policy.alpha_loss,
                    policy.td_error,
                )
                # NOTE(review): index slices assume a fixed ordering of the
                # watched variables — confirm if this branch is revived.
                tape_vars = tape.watched_variables()
                tf_c_grads = tape.gradient(c[0], tape_vars[6:10])
                tf_a_grads = tape.gradient(a, tape_vars[2:6])
                tf_e_grads = tape.gradient(e, tape_vars[10])

            elif fw == "torch":
                loss_torch(policy, policy.model, None, input_)
                c, a, e, t = (
                    policy.get_tower_stats("critic_loss")[0],
                    policy.get_tower_stats("actor_loss")[0],
                    policy.get_tower_stats("alpha_loss")[0],
                    policy.get_tower_stats("td_error")[0],
                )

                # Test actor gradients.
                policy.actor_optim.zero_grad()
                assert all(v.grad is None for v in policy.model.q_variables())
                assert all(v.grad is None
                           for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                a.backward()
                # `actor_loss` depends on Q-net vars (but these grads must
                # be ignored and overridden in critic_loss.backward!).
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.policy_variables())
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                    if v.grad is not None
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test critic gradients.
                policy.critic_optims[0].zero_grad()
                assert all(
                    torch.mean(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert all(
                    torch.min(v.grad) == 0.0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                c[0].backward()
                assert not all(
                    torch.mean(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert not all(
                    torch.min(v.grad) == 0
                    for v in policy.model.q_variables() if v.grad is not None)
                assert policy.model.log_alpha.grad is None
                # Compare with tf ones.
                torch_c_grads = [v.grad for v in policy.model.q_variables()]
                check(tf_c_grads[0],
                      np.transpose(torch_c_grads[2].detach().cpu()))
                # Compare (unchanged(!) actor grads) with tf ones.
                torch_a_grads = [
                    v.grad for v in policy.model.policy_variables()
                ]
                check(tf_a_grads[2],
                      np.transpose(torch_a_grads[0].detach().cpu()))

                # Test alpha gradient.
                policy.alpha_optim.zero_grad()
                assert policy.model.log_alpha.grad is None
                e.backward()
                assert policy.model.log_alpha.grad is not None
                check(policy.model.log_alpha.grad, tf_e_grads)

            # Loss values must match the helper-computed expectations.
            check(c, expect_c)
            check(a, expect_a)
            check(e, expect_e)
            check(t, expect_t)

            # Compare with the previous framework's losses (if any), then
            # store this framework's losses for the next framework.
            if prev_fw_loss is not None:
                check(c, prev_fw_loss[0])
                check(a, prev_fw_loss[1])
                check(e, prev_fw_loss[2])
                check(t, prev_fw_loss[3])

            prev_fw_loss = (c, a, e, t)

            # Update weights from our batch (5 times).
            for update_iteration in range(5):
                print("train iteration {}".format(update_iteration))
                if fw == "tf":
                    in_ = self._get_batch_helper(obs_size, actions, batch_size)
                    tf_inputs.append(in_)
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    updated_weights = policy.get_weights()
                    # Net must have changed.
                    if tf_updated_weights:
                        check(
                            updated_weights["default_policy/fc_1/kernel"],
                            tf_updated_weights[-1]
                            ["default_policy/fc_1/kernel"],
                            false=True,
                        )
                    tf_updated_weights.append(updated_weights)

                # Compare with updated tf-weights. Must all be the same.
                else:
                    tf_weights = tf_updated_weights[update_iteration]
                    in_ = tf_inputs[update_iteration]
                    # Set a fake-batch to use
                    # (instead of sampling from replay buffer).
                    buf = MultiAgentReplayBuffer.get_instance_for_testing()
                    buf._fake_batch = in_
                    trainer.train()
                    # Compare updated model (skipping target-net `_2`/`_3`
                    # and alpha variables, which are checked separately).
                    for tf_key in sorted(tf_weights.keys()):
                        if re.search("_[23]|alpha", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.model.state_dict()[map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
                    # And alpha.
                    check(policy.model.log_alpha,
                          tf_weights["default_policy/log_alpha"])
                    # Compare target nets.
                    for tf_key in sorted(tf_weights.keys()):
                        if not re.search("_[23]", tf_key):
                            continue
                        tf_var = tf_weights[tf_key]
                        torch_var = policy.target_model.state_dict()[
                            map_[tf_key]]
                        if tf_var.shape != torch_var.shape:
                            check(
                                tf_var,
                                np.transpose(torch_var.detach().cpu()),
                                atol=0.003,
                            )
                        else:
                            check(tf_var, torch_var, atol=0.003)
            trainer.stop()
Ejemplo n.º 9
0
class MyTrainer(Trainer):
    """Example Trainer that trains a PPO policy and a DQN policy side by side.

    Every iteration samples from the rollout workers until at least 200
    steps for ``"ppo_policy"`` have been gathered. The ``"ppo_policy"`` part
    of each sampled multi-agent batch is split off and trained on directly;
    the remainder of the batch is stored in a local replay buffer, from
    which ``"dqn_policy"`` is trained.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return the common config with a few overrides.

        Switches this Trainer to the `training_iteration` API and sets some
        PPO-specific SGD parameters.
        """
        overrides = {
            "_disable_execution_plan_api": True,
            "num_sgd_iter": 10,
            "sgd_minibatch_size": 128,
        }
        return with_common_config(overrides)

    @override(Trainer)
    def setup(self, config):
        """Set up rollout workers (via the parent) plus a local replay buffer."""
        # The parent's `setup` is what builds the rollout workers.
        super().setup(config)
        # Single-shard local buffer that receives the non-PPO (DQN) samples.
        self.local_replay_buffer = MultiAgentReplayBuffer(
            num_shards=1,
            learning_starts=1000,
            capacity=50000,
            replay_batch_size=64,
        )

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        """Sample once for both policies, then train DQN and PPO.

        Returns:
            The PPO learner results merged with the DQN learner results
            (the DQN part is empty when the buffer returned no batch).
        """
        collected_ppo = []
        ppo_steps = 0
        # Keep sampling until the PPO batch holds at least 200 steps.
        while ppo_steps < 200:
            for sampled in synchronous_parallel_sample(self.workers):
                # Count before splitting, so totals include every policy.
                self._counters[NUM_ENV_STEPS_SAMPLED] += sampled.count
                self._counters[
                    NUM_AGENT_STEPS_SAMPLED] += sampled.agent_steps()
                # Detach PPO's samples; whatever remains goes to the buffer.
                ppo_part = sampled.policy_batches.pop("ppo_policy")
                self.local_replay_buffer.add_batch(sampled)
                collected_ppo.append(ppo_part)
                ppo_steps += ppo_part.count

        # --- DQN: learn from replayed samples (if the buffer yields any). ---
        dqn_results = {}
        replayed = self.local_replay_buffer.replay()
        if replayed is not None:
            dqn_results = train_one_step(self, replayed, ["dqn_policy"])
            self._counters[
                "agent_steps_trained_DQN"] += replayed.agent_steps()
            print(
                "DQN policy learning on samples from",
                "agent steps trained",
                replayed.agent_steps(),
            )

        # Refresh DQN's target net once 500 more agent steps were trained.
        trained_dqn = self._counters["agent_steps_trained_DQN"]
        if trained_dqn - self._counters[LAST_TARGET_UPDATE_TS] >= 500:
            dqn_policy = self.workers.local_worker().get_policy("dqn_policy")
            dqn_policy.update_target()
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = trained_dqn

        # --- PPO: learn on-policy from this iteration's samples. ---
        ppo_samples = SampleBatch.concat_samples(collected_ppo)
        self._counters[
            "agent_steps_trained_PPO"] += ppo_samples.agent_steps()
        # Standardize advantages before the update.
        raw_advantages = ppo_samples[Postprocessing.ADVANTAGES]
        ppo_samples[Postprocessing.ADVANTAGES] = standardized(raw_advantages)
        print(
            "PPO policy learning on samples from",
            "agent steps trained",
            ppo_samples.agent_steps(),
        )
        ppo_multi = MultiAgentBatch({"ppo_policy": ppo_samples},
                                    ppo_samples.count)
        ppo_results = train_one_step(self, ppo_multi, ["ppo_policy"])

        # Merge both learner-result dicts (DQN entries win on key clashes).
        return dict(ppo_results, **dqn_results)
Ejemplo n.º 10
0
class MARWILTrainer(Trainer):
    """Trainer for the MARWIL algorithm.

    Supports both the `training_iteration` API (buffer created in `setup`)
    and the deprecated `execution_plan` API.
    """

    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        """Return MARWIL's default config (`DEFAULT_CONFIG`)."""
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        """Validate `config`, adding MARWIL-specific checks to the base ones.

        Raises:
            ValueError: If more than one GPU is requested, or if `beta` > 0
                while `postprocess_inputs` is False.
        """
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

        if config["postprocess_inputs"] is False and config["beta"] > 0.0:
            raise ValueError(
                "`postprocess_inputs` must be True for MARWIL (to "
                "calculate accum., discounted returns)!")

    @override(Trainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        """Return the torch or TF MARWIL policy class, based on `framework`."""
        if config["framework"] == "torch":
            # Imported lazily so the torch policy is only loaded when used.
            from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy

            return MARWILTorchPolicy
        else:
            return MARWILTFPolicy

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        """Run the base setup; create the local replay buffer if needed."""
        super().setup(config)
        # `training_iteration` implementation: Setup buffer in `setup`, not
        # in `execution_plan` (deprecated).
        if self.config["_disable_execution_plan_api"] is True:
            self.local_replay_buffer = MultiAgentReplayBuffer(
                learning_starts=self.config["learning_starts"],
                capacity=self.config["replay_buffer_size"],
                replay_batch_size=self.config["train_batch_size"],
                replay_sequence_length=1,
            )

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        """Run one training iteration (`training_iteration` API).

        Samples from all rollout workers, stores the samples in the local
        replay buffer, trains on a batch replayed from that buffer, then
        pushes the new weights and global vars to the workers.

        Returns:
            The learner results of the train step, or an empty dict if the
            replay buffer did not return a batch to train on.
        """
        # Collect SampleBatches from sample workers.
        batch = synchronous_parallel_sample(worker_set=self.workers)
        batch = batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        # Add batch to replay buffer.
        self.local_replay_buffer.add_batch(batch)

        # Pull batch from replay buffer and train on it. `replay()` may
        # return None (presumably while fewer than `learning_starts` items
        # are stored -- confirm against MultiAgentReplayBuffer); in that
        # case skip training instead of passing None to the train ops.
        train_results = {}
        train_batch = self.local_replay_buffer.replay()
        if train_batch is not None:
            if self.config["simple_optimizer"]:
                train_results = train_one_step(self, train_batch)
            else:
                train_results = multi_gpu_train_one_step(self, train_batch)
            # Count the steps actually trained on (the replayed batch), not
            # the freshly sampled one.
            # NOTE(review): if the train ops above already update these
            # counters themselves, these two lines double-count -- confirm.
            self._counters[NUM_AGENT_STEPS_TRAINED] += train_batch.agent_steps()
            self._counters[NUM_ENV_STEPS_TRAINED] += train_batch.env_steps()

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)

        # Update global vars on local worker as well.
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        """Build the (deprecated) execution-plan dataflow for MARWIL.

        One op stores bulk-synchronous rollouts into a replay buffer. A
        second op replays from it, concatenates the replayed batches up to
        `train_batch_size`, and trains. The two ops run round-robin, and
        only the training op's output is reported.
        """
        assert (
            len(kwargs) == 0
        ), "Marwill execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = MultiAgentReplayBuffer(
            learning_starts=config["learning_starts"],
            capacity=config["replay_buffer_size"],
            replay_batch_size=config["train_batch_size"],
            replay_sequence_length=1,
        )

        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))

        replay_op = (Replay(local_buffer=replay_buffer).combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )).for_each(TrainOneStep(workers)))

        # Only emit the training op's results (index 1).
        train_op = Concurrently([store_op, replay_op],
                                mode="round_robin",
                                output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)