Example No. 1
    def test_overriding_default_resource_request(self):
        config = DEFAULT_CONFIG.copy()
        config["model"]["fcnet_hiddens"] = [10]
        config["num_workers"] = 2
        # 3 trials, but only 2 can run at a time: the cluster has num_cpus=6
        # and each trial needs 3 CPUs (1 head bundle + 2 child bundles).
        config["lr"] = tune.grid_search([0.1, 0.01, 0.001])
        config["env"] = "CartPole-v0"
        config["framework"] = "tf"

        class DefaultResourceRequest:
            @classmethod
            def default_resource_request(cls, config):
                head_bundle = {"CPU": 1, "GPU": 0}
                child_bundle = {"CPU": 1}
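                # NOTE: config["placement_strategy"] is assumed to be set
                # outside this snippet (e.g., to "PACK" or "SPREAD").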
                return PlacementGroupFactory(
                    [head_bundle, child_bundle, child_bundle],
                    strategy=config["placement_strategy"])

        # Create a trainer with an overridden default_resource_request
        # method that returns a PlacementGroupFactory.
        MyTrainer = PGTrainer.with_updates(mixins=[DefaultResourceRequest])
        tune.register_trainable("my_trainable", MyTrainer)

        # Keep a module-level handle on the executor, presumably so that
        # _TestCallback (defined elsewhere in the test module) can inspect it.
        global trial_executor
        trial_executor = RayTrialExecutor(reuse_actors=False)

        tune.run(
            "my_trainable",
            config=config,
            stop={"training_iteration": 2},
            trial_executor=trial_executor,
            callbacks=[_TestCallback()],
            verbose=2,
        )
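
The method above is lifted from a larger test module; a minimal sketch of the imports it relies on, assuming the Ray 1.x era Tune/RLlib APIs that still expose with_updates, could look like this:

# Assumed imports for the snippet above (Ray 1.x era APIs);
# _TestCallback is defined elsewhere in the same test file.
from ray import tune
from ray.rllib.agents.pg import DEFAULT_CONFIG, PGTrainer
from ray.tune.ray_trial_executor import RayTrialExecutor
from ray.tune.utils.placement_groups import PlacementGroupFactory
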
Example No. 2
def run_with_custom_entropy_loss(args, stop):
    """Example of customizing the loss function of an existing policy.

    This performs about the same as the default loss does."""
    def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        if args.torch:
            # required by PGTorchPolicy's stats fn.
            policy.pi_err = torch.tensor([0.0])
            return torch.mean(-0.1 * action_dist.entropy() -
                              (action_dist.logp(train_batch["actions"]) *
                               train_batch["advantages"]))
        else:
            return (-0.1 * action_dist.entropy() - tf.reduce_mean(
                action_dist.logp(train_batch["actions"]) *
                train_batch["advantages"]))

    policy_cls = PGTorchPolicy if args.torch else PGTFPolicy
    EntropyPolicy = policy_cls.with_updates(
        loss_fn=entropy_policy_gradient_loss)

    EntropyLossPG = PGTrainer.with_updates(
        name="EntropyPG", get_policy_class=lambda _: EntropyPolicy)

    run_heuristic_vs_learned(args, use_lstm=True, trainer=EntropyLossPG)
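
The imports here are again assumed from the surrounding example script (Ray 1.x era APIs; run_heuristic_vs_learned is defined elsewhere in the same file). The -0.1 * entropy term simply adds a small entropy bonus on top of the vanilla policy-gradient term, which is why the docstring notes that it performs about the same as the default loss. A minimal sketch of the assumed imports:

# Assumed imports for the snippet above (Ray 1.x era APIs). The original
# script may use RLlib's try_import_tf / try_import_torch helpers instead
# of plain framework imports.
import tensorflow as tf
import torch

from ray.rllib.agents.pg import PGTFPolicy, PGTorchPolicy, PGTrainer
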
Example No. 3
def main(debug):
    ray.init(num_cpus=os.cpu_count(), num_gpus=0)

    stop = {"episodes_total": 10 if debug else 400}

    env_config = {
        "max_steps": 10,
        "players_ids": ["player_row", "player_col"],
    }

    policies = {
        env_config["players_ids"][0]: (
            None, IteratedBoSAndPD.OBSERVATION_SPACE,
            IteratedBoSAndPD.ACTION_SPACE, {}),
        env_config["players_ids"][1]: (
            None, IteratedBoSAndPD.OBSERVATION_SPACE,
            IteratedBoSAndPD.ACTION_SPACE, {}),
    }

    rllib_config = {
        "env": IteratedBoSAndPD,
        "env_config": env_config,
        "num_gpus": 0,
        "num_workers": 1,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": (lambda agent_id: agent_id),
        },
        "framework": "torch",
        "gamma": 0.5,
        "callbacks": miscellaneous.merge_callbacks(
            log.get_logging_callbacks_class(),
            postprocessing.OverwriteRewardWtWelfareCallback),
    }

    MyPGTorchPolicy = PGTorchPolicy.with_updates(
        postprocess_fn=miscellaneous.merge_policy_postprocessing_fn(
            postprocessing.get_postprocessing_welfare_function(
                add_inequity_aversion_welfare=True,
                inequity_aversion_beta=1.0,
                inequity_aversion_alpha=0.0,
                inequity_aversion_gamma=1.0,
                inequity_aversion_lambda=0.5),
            pg_torch_policy.post_process_advantages))
    MyPGTrainer = PGTrainer.with_updates(default_policy=MyPGTorchPolicy,
                                         get_policy_class=None)
    tune_analysis = tune.run(MyPGTrainer,
                             stop=stop,
                             checkpoint_freq=10,
                             config=rllib_config)
    ray.shutdown()
    return tune_analysis
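
Only the generic imports for this snippet can be sketched with any confidence; IteratedBoSAndPD, miscellaneous, log, and postprocessing are project-specific modules from the codebase this function was extracted from, so they are left as-is. A minimal sketch, assuming Ray 1.x era APIs:

# Assumed generic imports for the snippet above (Ray 1.x era APIs).
# IteratedBoSAndPD, miscellaneous, log and postprocessing are
# project-specific modules from the surrounding codebase.
import os

import ray
from ray import tune
from ray.rllib.agents.pg import PGTrainer, PGTorchPolicy, pg_torch_policy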