Example #1
# Imports assume the Ray 1.x RLlib layout; make_workers is a test helper
# (defined in RLlib's test_execution.py) that builds a WorkerSet producing
# 100-step sample batches.
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts


def test_store_to_replay_local(ray_start_regular_shared):
    buf = LocalReplayBuffer(num_shards=1,
                            learning_starts=200,
                            buffer_size=1000,
                            replay_batch_size=100,
                            prioritized_replay_alpha=0.6,
                            prioritized_replay_beta=0.4,
                            prioritized_replay_eps=0.0001)
    assert buf.replay() is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(local_buffer=buf))

    next(b)
    assert buf.replay() is None  # learning hasn't started yet
    next(b)  # second 100-step batch -> 200 stored, learning_starts reached
    assert buf.replay().count == 100

    replay_op = Replay(local_buffer=buf)
    assert next(replay_op).count == 100
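
The gating the test asserts comes entirely from the buffer: replay() returns None until learning_starts (here 200) timesteps have been added, and then samples batches of replay_batch_size. A minimal sketch of that behavior with synthetic data, assuming the same Ray 1.x LocalReplayBuffer API (all field values below are placeholders):

import numpy as np
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch

buf = LocalReplayBuffer(num_shards=1,
                        learning_starts=200,
                        buffer_size=1000,
                        replay_batch_size=100)

# 100 synthetic timesteps with just the standard SampleBatch fields.
fake = SampleBatch({
    "obs": np.zeros((100, 4), dtype=np.float32),
    "new_obs": np.zeros((100, 4), dtype=np.float32),
    "actions": np.zeros(100, dtype=np.int64),
    "rewards": np.zeros(100, dtype=np.float32),
    "dones": np.zeros(100, dtype=bool),
    "eps_id": np.zeros(100, dtype=np.int64),
})

buf.add_batch(fake)
assert buf.replay() is None       # 100 < learning_starts
buf.add_batch(fake)
assert buf.replay().count == 100  # warmup reached; sampling begins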
Example #2
# Imports assume the Ray 1.x RLlib layout.
import copy

import torch

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.agents.dqn import DEFAULT_CONFIG
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.evaluation.sample_batch_builder import \
    MultiAgentSampleBatchBuilder
from ray.rllib.execution.replay_buffer import LocalReplayBuffer


class DQNAgent:
    def __init__(
        self,
        env,
        batch_size,
        trace_length,
        grid_size,
        exploiter_base_lr,
        exploiter_decay_lr_in_n_epi,
        exploiter_stop_training_after_n_epi,
        train_exploiter_n_times_per_epi,
    ):

        self.stop_training_after_n_epi = exploiter_stop_training_after_n_epi
        self.train_exploiter_n_times_per_epi = train_exploiter_n_times_per_epi

        # with tf.variable_scope(f"dqn_exploiter"):
        # Build the DQN config for the exploiter policy.
        dqn_config = copy.deepcopy(DEFAULT_CONFIG)
        dqn_config.update({
            "prioritized_replay": False,
            "double_q": True,
            "buffer_size": 50000,
            "dueling": False,
            "learning_starts": min(
                int((batch_size - 1) * (trace_length - 1)), 64),
            "model": {
                "dim": grid_size,
                # [out_channels, [kernel, kernel], stride]
                "conv_filters": [[16, [3, 3], 1], [32, [3, 3], 1]],
                # "fcnet_hiddens": [self.env.NUM_ACTIONS],
                "max_seq_len": trace_length,
                # Hidden layer sizes for the fully connected net.
                "fcnet_hiddens": [64],
                # Nonlinearity for the fully connected net (tanh, relu).
                "fcnet_activation": "relu",
            },
            # Update the replay buffer with this many samples at once. Note
            # that this setting applies per-worker if num_workers > 1.
            "rollout_fragment_length": 1,
            # Size of a batch sampled from the replay buffer for training.
            # Note that if async_updates is set, then each worker returns
            # gradients for a batch of this size.
            "train_batch_size": min(int(batch_size * trace_length), 64),
            "explore": False,
            "grad_clip": 1,
            "gamma": 0.5,
            "lr": exploiter_base_lr,
            # Learning rate schedule: warm up, then decay to ~0 after
            # exploiter_decay_lr_in_n_epi episodes.
            "lr_schedule": [
                (0, exploiter_base_lr / 1000),
                (100, exploiter_base_lr),
                (exploiter_decay_lr_in_n_epi, exploiter_base_lr / 1e9),
            ],
            # Custom key consumed by sgd_optimizer_dqn below.
            "sgd_momentum": 0.9,
        })
        print("dqn_config", dqn_config)

        self.local_replay_buffer = LocalReplayBuffer(
            num_shards=1,
            learning_starts=dqn_config["learning_starts"],
            buffer_size=dqn_config["buffer_size"],
            replay_batch_size=dqn_config["train_batch_size"],
            replay_mode=dqn_config["multiagent"]["replay_mode"],
            replay_sequence_length=dqn_config["replay_sequence_length"],
        )

        # self.dqn_exploiter = DQNTFPolicy(obs_space=self.env.OBSERVATION_SPACE,
        #                                  action_space=self.env.ACTION_SPACE,
        #                                  config=dqn_config)

        def sgd_optimizer_dqn(policy, config) -> "torch.optim.Optimizer":
            return torch.optim.SGD(
                policy.q_func_vars,
                lr=policy.cur_lr,
                momentum=config["sgd_momentum"],
            )

        MyDQNTorchPolicy = DQNTorchPolicy.with_updates(
            optimizer_fn=sgd_optimizer_dqn)
        self.dqn_policy = MyDQNTorchPolicy(
            obs_space=env.OBSERVATION_SPACE,
            action_space=env.ACTION_SPACE,
            config=dqn_config,
        )

        self.multi_agent_batch_builders = [
            MultiAgentSampleBatchBuilder(
                policy_map={"player_blue": self.dqn_policy},
                clip_rewards=False,
                callbacks=DefaultCallbacks(),
            )
            # for _ in range(self.batch_size)
        ]

    def compute_actions(self, obs_batch):
        # Returns the (actions, rnn_state_outs, info) tuple produced by
        # Policy.compute_actions.
        return self.dqn_policy.compute_actions(obs_batch=obs_batch)

    def add_data_in_rllib_batch_builder(self, s, s1P, trainBatch1, d,
                                        timestep):
        if timestep <= self.stop_training_after_n_epi:
            # for i in range(self.batch_size):
            i = 0
            step_player_values = {
                "eps_id": timestep,
                "obs": s[i],
                "new_obs": s1P[i],
                "actions": trainBatch1[1][-1][i],
                "prev_actions": trainBatch1[1][-2][i]
                if len(trainBatch1[1]) > 1 else 0,
                "rewards": trainBatch1[2][-1][i],
                "prev_rewards": trainBatch1[2][-2][i]
                if len(trainBatch1[2]) > 1 else 0,
                # done is the same for every episode in the batch
                "dones": d[0],
            }
            self.multi_agent_batch_builders[i].add_values(
                agent_id="player_blue",
                policy_id="player_blue",
                **step_player_values,
            )
            self.multi_agent_batch_builders[i].count += 1

    def train_dqn_policy(self, timestep):
        stats = {"learner_stats": {}}
        if timestep <= self.stop_training_after_n_epi:
            # Build the episode batch and add it to the replay buffer.
            # for i in range(self.batch_size):
            i = 0
            builder = self.multi_agent_batch_builders[i]
            multiagent_batch = builder.build_and_reset()
            self.local_replay_buffer.add_batch(multiagent_batch)

            # update lr in scheduler & in optimizer
            self.dqn_policy.on_global_var_update({"timestep": timestep})
            self.dqn_policy.optimizer()
            if hasattr(self.dqn_policy, "cur_lr"):
                for opt in self.dqn_policy._optimizers:
                    for p in opt.param_groups:
                        p["lr"] = self.dqn_policy.cur_lr
            # Generate training batch and train
            for _ in range(self.train_exploiter_n_times_per_epi):
                replay_batch = self.local_replay_buffer.replay()
                # replay() returns None until learning_starts steps have
                # been collected.
                if replay_batch is not None:
                    stats = self.dqn_policy.learn_on_batch(
                        replay_batch.policy_batches["player_blue"])

        stats["learner_stats"]["exploiter_lr_cur"] = self.dqn_policy.cur_lr
        for opt in self.dqn_policy._optimizers:
            stats["learner_stats"]["exploiter_lr_from_params"] = (
                opt.param_groups[0]["lr"])
        return stats
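
The class above only defines the pieces; a caller has to drive them. A hypothetical per-episode loop (env, n_episodes, and trainBatch1 are placeholders for the caller's own environment and bookkeeping) showing the intended call order:

agent = DQNAgent(env,
                 batch_size=1,
                 trace_length=20,
                 grid_size=3,
                 exploiter_base_lr=0.1,
                 exploiter_decay_lr_in_n_epi=2000,
                 exploiter_stop_training_after_n_epi=3000,
                 train_exploiter_n_times_per_epi=1)

for episode in range(n_episodes):
    s = env.reset()
    done = False
    while not done:
        actions, _, _ = agent.compute_actions(s)
        s1, rewards, done, _ = env.step(actions)
        # trainBatch1 is the caller-maintained [obs, actions, rewards]
        # history that add_data_in_rllib_batch_builder indexes into.
        agent.add_data_in_rllib_batch_builder(s, s1, trainBatch1,
                                              [done], episode)
        s = s1
    # Build the episode batch, push it to the replay buffer, and train.
    stats = agent.train_dqn_policy(episode)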