Example #1
    def _generate_trials(self, experiment_spec, output_path=""):
        """Generates trials with configurations from `_suggest`.

        Creates a trial_id that is passed into `_suggest`.

        Yields:
            Trial objects constructed according to `spec`
        """
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))
        for _ in range(experiment_spec.get("num_samples", 1)):
            trial_id = Trial.generate_id()
            while True:
                suggested_config = self._suggest(trial_id)
                if suggested_config is None:
                    yield None
                else:
                    break
            spec = copy.deepcopy(experiment_spec)
            spec["config"] = merge_dicts(spec["config"], suggested_config)
            flattened_config = resolve_nested_dict(spec["config"])
            self._counter += 1
            tag = "{0}_{1}".format(str(self._counter),
                                   format_vars(flattened_config))
            yield create_trial_from_spec(spec,
                                         output_path,
                                         self._parser,
                                         experiment_tag=tag,
                                         trial_id=trial_id)
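
A hedged sketch of the `_suggest` contract this generator relies on: `_suggest(trial_id)` should return a configuration dict when one is ready, or None to signal that the generator should yield None and ask again. The class and names below are illustrative only and do not come from the snippet above.

class QueueSearcher:
    """Hypothetical searcher that hands out pre-built configs one at a time."""

    def __init__(self, configs):
        self._queue = list(configs)

    def _suggest(self, trial_id):
        # None means "no configuration available yet"; `_generate_trials`
        # responds by yielding None and retrying with the same trial_id.
        if not self._queue:
            return None
        return self._queue.pop(0)

searcher = QueueSearcher([{"lr": 1e-3}, {"lr": 1e-4}])
print(searcher._suggest("trial_1"))  # -> {'lr': 0.001}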
Example #2
    def _generate_trials(self, experiment_spec, output_path=""):
        """Generates trials with configurations from `_suggest`.

        Creates a trial_id that is passed into `_suggest`.

        Yields:
            Trial objects constructed according to `spec`
        """
        if "run" not in experiment_spec:
            raise TuneError("Must specify `run` in {}".format(experiment_spec))
        for _ in range(experiment_spec.get("num_samples", 1)):
            trial_id = Trial.generate_id()
            while True:
                suggested_config = self._suggest(trial_id)
                if suggested_config is None:
                    yield None
                else:
                    break
            spec = copy.deepcopy(experiment_spec)
            spec["config"] = merge_dicts(spec["config"], suggested_config)
            flattened_config = resolve_nested_dict(spec["config"])
            self._counter += 1
            tag = "{0}_{1}".format(
                str(self._counter), format_vars(flattened_config))
            yield create_trial_from_spec(
                spec,
                output_path,
                self._parser,
                experiment_tag=tag,
                trial_id=trial_id)
Example #3
def run(args, parser):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    rollout(agent, args.env, num_steps, args.out, args.no_render)
Example #4
def run(args, parser):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, 'rb') as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    rollout(agent, args.env, num_steps, args.out, args.no_render)
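
The params.pkl lookup repeated across these rollout snippets (checkpoint directory first, then its parent, then the --config override) can be read as the helper below. This is our own sketch; the function name is not part of the original scripts.

import os
import pickle

def load_checkpoint_config(checkpoint, fallback_config=None):
    """Return the pickled config stored next to (or one level above) a checkpoint."""
    config_dir = os.path.dirname(checkpoint)
    for candidate in (os.path.join(config_dir, "params.pkl"),
                      os.path.join(config_dir, "..", "params.pkl")):
        if os.path.exists(candidate):
            with open(candidate, "rb") as f:
                return pickle.load(f)
    if not fallback_config:
        raise ValueError(
            "Could not find params.pkl in either the checkpoint dir or "
            "its parent directory.")
    return fallback_config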
Example #5
def main():
    parser = create_parser_custom()
    args = parser.parse_args()

    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)

    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    config['num_workers'] = 0
    config['num_gpus'] = 0
    config['num_envs_per_worker'] = 1

    # whether to run the Doom env at its default FPS (ASYNC mode)
    async_mode = args.fps == 0

    skip_frames = args.env_frameskip

    bot_difficulty = args.bot_difficulty

    record_to = join(args.record_to, f'{config["env"]}_{args._run}')

    custom_resolution = args.custom_res

    register_doom_envs_rllib(
        async_mode=async_mode,
        skip_frames=skip_frames,
        num_agents=args.num_agents,
        num_bots=args.num_bots,
        num_humans=args.num_humans,
        bot_difficulty=bot_difficulty,
        record_to=record_to,
        custom_resolution=custom_resolution,
    )

    ModelCatalog.register_custom_model('vizdoom_vision_model',
                                       VizdoomVisionNetwork)

    run(args, config)
Example #6
def run(args, parser):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    # remove unnecessary parameters
    if "num_workers" in config:
        del config["num_workers"]
    if "human_data_dir" in config["optimizer"]:
        del config["optimizer"]["human_data_dir"]
    if "human_demonstration" in config["optimizer"]:
        del config["optimizer"]["human_demonstration"]
    if "multiple_human_data" in config["optimizer"]:
        del config["optimizer"]["multiple_human_data"]
    if "num_replay_buffer_shards" in config["optimizer"]:
        del config["optimizer"]["num_replay_buffer_shards"]
    if "demonstration_zone_percentage" in config["optimizer"]:
        del config["optimizer"]["demonstration_zone_percentage"]
    if "dynamic_experience_replay" in config["optimizer"]:
        del config["optimizer"]["dynamic_experience_replay"]
    if "robot_demo_path" in config["optimizer"]:
        del config["optimizer"]["robot_demo_path"]

    ray.init()

    # cls = get_agent_class(args.run)
    # agent = cls(env=args.env, config=config)

    cls = get_agent_class("DDPG")
    agent = cls(env="ROBOTIC_ASSEMBLY", config=config)

    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    num_episodes = int(args.episodes)
    rollout(agent, args.env, num_steps, num_episodes, args.out)
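
The block of `if ... del ...` statements above removes optimizer settings the restored agent no longer expects; the same cleanup can be written with a small helper, as in the sketch below (the helper name and sample dict are ours).

def strip_keys(d, keys):
    """Remove the listed keys from a dict when they are present."""
    for key in keys:
        d.pop(key, None)

config = {"num_workers": 2,
          "optimizer": {"human_data_dir": "/tmp/demos", "lr": 1e-3}}
config.pop("num_workers", None)
strip_keys(config["optimizer"], [
    "human_data_dir", "human_demonstration", "multiple_human_data",
    "num_replay_buffer_shards", "demonstration_zone_percentage",
    "dynamic_experience_replay", "robot_demo_path",
])
print(config)  # -> {'optimizer': {'lr': 0.001}}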
Example #7
def run(args, parser):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    ray.init()

    cls = get_agent_class(args.run)
    agent = cls(env=args.env, config=config)
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    num_episodes = int(args.episodes)
    with RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info) as saver:
        rollout(agent, args.env, num_steps, num_episodes, saver,
                args.no_render, args.monitor)
Example #8
def get_config(args):
    config = {}
    # Load configuration from file
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)
    if "num_workers" in config:
        config["num_workers"] = min(2, config["num_workers"])
    config["num_gpus"] = 0
    config = merge_dicts(config, args.config)
    if not args.env:
        if not config.get("env"):
            raise ValueError("the following arguments are required: --env")
        args.env = config.get("env")
    return config
Example #9
def run_experiment(args, parser):
    # args.ray_object_store_memory = int(1e10)
    args.ray_redis_max_memory = int(2e9)

    if args.config_file:
        with open(args.config_file) as f:
            exp = yaml.load(f)
    else:
        raise Exception('No config file!')

    exp = merge_dicts(exp, args.config)
    log.info('Num workers: %d, num_envs_per_worker: %d',
             exp['config']['num_workers'],
             exp['config']['num_envs_per_worker'])

    if args.cfg_mixins is not None:
        for cfg_mixin_file in args.cfg_mixins:
            with open(cfg_mixin_file, 'r') as f:
                override_cfg = yaml.load(f)
                log.info('Overriding parameters from %s: %r', cfg_mixin_file,
                         override_cfg)
                exp = merge_dicts(exp, override_cfg)

    if not exp.get("run"):
        parser.error("the following arguments are required: --run")
    if not exp.get("env") and not exp.get("config", {}).get("env"):
        parser.error("the following arguments are required: --env")

    if args.ray_num_nodes:
        cluster = Cluster()
        for _ in range(args.ray_num_nodes):
            cluster.add_node(
                num_cpus=args.ray_num_cpus or 1,
                num_gpus=args.ray_num_gpus or 0,
                object_store_memory=args.ray_object_store_memory,
                redis_max_memory=args.ray_redis_max_memory,
            )
        ray.init(redis_address=cluster.redis_address,
                 local_mode=args.local_mode)
    else:
        ray.init(
            redis_address=args.redis_address,
            object_store_memory=args.ray_object_store_memory,
            redis_max_memory=args.ray_redis_max_memory,
            num_cpus=args.ray_num_cpus,
            num_gpus=args.ray_num_gpus,
            local_mode=args.local_mode,
        )

    exp = Experiment.from_json(args.experiment_name, exp)
    exp.spec['checkpoint_freq'] = 20
    if args.pbt:
        exp.spec['checkpoint_freq'] = 3

    exp.spec['checkpoint_at_end'] = True
    # exp.spec['checkpoint_score_attr'] = 'episode_reward_mean'
    exp.spec['keep_checkpoints_num'] = 5

    if args.stop_seconds > 0:
        exp.spec['stop'] = {'time_total_s': args.stop_seconds}

    # if 'multiagent' in exp.spec['config']:
    #     # noinspection PyProtectedMember
    #     make_env = ray.tune.registry._global_registry.get(ENV_CREATOR, exp.spec['config']['env'])
    #     temp_env = make_env(None)
    #     obs_space, action_space = temp_env.observation_space, temp_env.action_space
    #     temp_env.close()
    #     del temp_env
    #
    #     policies = dict(
    #         main=(None, obs_space, action_space, {}),
    #         dummy=(None, obs_space, action_space, {}),
    #     )
    #
    #     exp.spec['config']['multiagent'] = {
    #         'policies': policies,
    #         'policy_mapping_fn': function(lambda agent_id: 'main'),
    #         'policies_to_train': ['main'],
    #     }
    #
    # if args.dbg:
    #     exp.spec['config']['num_workers'] = 1
    #     exp.spec['config']['num_gpus'] = 1
    #     exp.spec['config']['num_envs_per_worker'] = 1
    #
    # if 'callbacks' not in exp.spec['config']:
    #     exp.spec['config']['callbacks'] = {}
    #
    # fps_helper = FpsHelper()
    #
    # def on_train_result(info):
    #     if 'APPO' in exp.spec['run']:
    #         samples = info['result']['info']['num_steps_sampled']
    #     else:
    #         samples = info['trainer'].optimizer.num_steps_trained
    #
    #     fps_helper.record(samples)
    #     fps = fps_helper.get_fps()
    #     info['result']['custom_metrics']['fps'] = fps
    #
    #     # remove this as currently
    #     skip_frames = exp.spec['config']['env_config']['skip_frames']
    #     info['result']['custom_metrics']['fps_frameskip'] = fps * skip_frames
    #
    # exp.spec['config']['callbacks']['on_train_result'] = function(on_train_result)
    #
    # def on_episode_end(info):
    #     episode = info['episode']
    #     stats = {
    #         'DEATHCOUNT': 0,
    #         'FRAGCOUNT': 0,
    #         'HITCOUNT': 0,
    #         'DAMAGECOUNT': 0,
    #         'KDR': 0,
    #         'FINAL_PLACE': 0,
    #         'LEADER_GAP': 0,
    #         'PLAYER_COUNT': 0,
    #         'BOT_DIFFICULTY': 0,
    #     }
    #
    #     # noinspection PyProtectedMember
    #     agent_to_last_info = episode._agent_to_last_info
    #     for agent in agent_to_last_info.keys():
    #         agent_info = agent_to_last_info[agent]
    #         for stats_key in stats.keys():
    #             stats[stats_key] += agent_info.get(stats_key, 0.0)
    #
    #     for stats_key in stats.keys():
    #         stats[stats_key] /= len(agent_to_last_info.keys())
    #
    #     episode.custom_metrics.update(stats)
    #
    # exp.spec['config']['callbacks']['on_episode_end'] = function(on_episode_end)

    extra_kwargs = {}
    if args.pbt:
        extra_kwargs['reuse_actors'] = False

    run(exp,
        name=args.experiment_name,
        scheduler=make_custom_scheduler(args),
        resume=args.resume,
        queue_trials=args.queue_trials,
        **extra_kwargs)
Example #10
    # Notice that trial_max will only work for stochastic policies
    register_env(
        "ic20env", lambda _: SimplifiedIC20Environment(obs_state_processor,
                                                       act_state_processor,
                                                       UnstableReward(),
                                                       trial_max=10))
    ten_gig = 10737418240

    trainer = A2CTrainer(
        env="ic20env",
        config=merge_dicts(
            DEFAULT_CONFIG,
            {
                # -- Specific parameters
                'num_gpus': 0,
                'num_workers': 15,
                "num_envs_per_worker": 1,
                "num_cpus_per_worker": 1,
                "memory_per_worker": ten_gig,
                'gamma': 0.99,
            }))

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        checkpoint_path = open(CHECKPOINT_FILE).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop
    while True:
        print(pretty_print(trainer.train()))
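
The restore branch above assumes something previously wrote the latest checkpoint path into CHECKPOINT_FILE. A sketch of the matching save side of the loop, assuming the standard trainer.save() call that returns a checkpoint path; the periodic-save logic is ours, not from the snippet, and reuses trainer, pretty_print, and CHECKPOINT_FILE from above.

    # Serving and training loop with periodic checkpointing (sketch)
    while True:
        result = trainer.train()
        print(pretty_print(result))
        checkpoint_path = trainer.save()       # returns the path of the new checkpoint
        with open(CHECKPOINT_FILE, "w") as f:  # let the next run restore from it
            f.write(checkpoint_path)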
Example #11
    # Notice that trial_max will only work for stochastic policies
    register_env(
        "ic20env", lambda _: SimplifiedIC20Environment(obs_state_processor,
                                                       act_state_processor,
                                                       UnstableReward(),
                                                       trial_max=10))
    ten_gig = 10737418240
    trainer = ImpalaTrainer(
        env="ic20env",
        config=merge_dicts(
            DEFAULT_CONFIG,
            {
                # -- Rollout-Worker
                'num_gpus': 1,
                'num_workers': 15,
                "num_envs_per_worker": 1,
                "num_cpus_per_worker": 0.7,
                "memory_per_worker": ten_gig,

                # MDP
                'gamma': 0.99,
            }))

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        checkpoint_path = open(CHECKPOINT_FILE).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop
    while True:
        print(pretty_print(trainer.train()))
Example #12
    # Notice that trial_max will only work for stochastic policies
    register_env(
        "ic20env", lambda _: SimplifiedIC20Environment(obs_state_processor,
                                                       act_state_processor,
                                                       SparseReward(),
                                                       trial_max=10))
    ten_gig = 10737418240

    trainer = PPOTrainer(
        env="ic20env",
        config=merge_dicts(
            DEFAULT_CONFIG,
            {
                # -- Rollout-Worker
                'num_gpus': 1,
                'num_workers': 10,
                "num_envs_per_worker": 1,
                "num_cpus_per_worker": 1,
                "memory_per_worker": ten_gig,
                'gamma': 0.99,
                'lambda': 0.95
            }))

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        checkpoint_path = open(CHECKPOINT_FILE).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop
    while True:
        print(pretty_print(trainer.train()))
Example #13
        config=merge_dicts(DEFAULT_CONFIG, {
            # -- Rollout-Worker
            'num_gpus': 1,
            'num_workers': 15,
            "num_envs_per_worker": 1,
            "num_cpus_per_worker": 0.5,
            "memory_per_worker": ten_gig,

            # -- Specific parameters
            "use_gae": True,
            "kl_coeff": 0.2,
            "kl_target": 0.01,

            # GAE(gamma) parameter
            'lambda': 0.8,
            # Max global norm for each worker gradient
            'grad_clip': 40.0,
            'lr': 0.0001,
            'lr_schedule': [[100000, 0.00005], [1000000, 0.00001], [20000000, 0.0000001]],
            'vf_loss_coeff': 0.5,
            'entropy_coeff': 0.01,
            # MDP
            'gamma': 0.99,
            "clip_rewards": True,  # a2c_std: True

            # -- Batches
            "sample_batch_size": 1000,  # std: 200
            "train_batch_size": 4000,
            'batch_mode': 'complete_episodes',
            "sgd_minibatch_size": 128
        }))
Example #14
    # Notice that trial_max will only work for stochastic policies
    register_env(
        'ic20env', lambda _: SimplifiedIC20Environment(obs_state_processor,
                                                       act_state_processor,
                                                       UnstableReward(),
                                                       trial_max=10))
    ten_gig = 10737418240

    trainer = ApexTrainer(
        env="ic20env",
        config=merge_dicts(
            DEFAULT_CONFIG,
            {
                # -- Rollout-Worker
                'num_gpus': 1,
                'num_workers': 15,
                "num_envs_per_worker": 1,
                "num_cpus_per_worker": 0.7,
                "memory_per_worker": ten_gig,
                "n_step": 3,
                'lr': 0.0005,
            }))

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        checkpoint_path = open(CHECKPOINT_FILE).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop
    while True:
        print(pretty_print(trainer.train()))
Example #15
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy, SampleBatch
from ray.rllib.agents.ddpg.td3 import TD3Trainer, TD3_DEFAULT_CONFIG
from ray.rllib.optimizers.sync_replay_optimizer import SyncReplayOptimizer, \
    PrioritizedReplayBuffer, MultiAgentBatch, ray_get_and_free, \
    pack_if_needed, \
    DEFAULT_POLICY_ID, get_learner_stats, np
from ray.tune.util import merge_dicts

from toolbox.marl.utils import on_train_result

DISABLE = "disable"
SHARE_SAMPLE = "share_sample"

cetd3_default_config = merge_dicts(
    TD3_DEFAULT_CONFIG,
    dict(mode=SHARE_SAMPLE, callbacks={"on_train_result": on_train_result})
    # dict(learn_with_peers=True, use_joint_dataset=False, mode=REPLAY_VALUES)
)
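
merge_dicts from ray.tune.util performs a deep merge, so nested override keys replace only the matching leaves while untouched leaves survive. A small illustrative check, using the merge_dicts imported above (the dicts and values are ours):

base = {"optimizer": {"lr": 1e-3, "eps": 1e-8}, "num_workers": 2}
override = {"optimizer": {"lr": 5e-4}}
print(merge_dicts(base, override))
# Expected: {'optimizer': {'lr': 0.0005, 'eps': 1e-08}, 'num_workers': 2}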


class SyncReplayOptimizerWithCooperativeExploration(SyncReplayOptimizer):
    def _replay(self):
        samples = super()._replay()

        # Add other's batch here.
        config = self.workers._local_config
        if config["mode"] == SHARE_SAMPLE:
            share_sample = SampleBatch.concat_samples(
                [batch for batch in samples.policy_batches.values()]
            )
            for pid in samples.policy_batches.keys():
Example #16
def validate():
    pair = "XRPBTC"
    interval = "1h"
    algo = "ARS"
    features = ["close", "sma15", "sma50"]
    ENV_NAME = "Projectlife-v1"
    register_env(
        ENV_NAME, lambda config: Environment(mode="validate",
                                             pair=pair,
                                             interval=interval,
                                             algo=algo,
                                             data_features=features))
    parser = create_parser()
    args = parser.parse_args()
    config = {}
    config_dir = os.path.dirname(args.checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")
    if not os.path.exists(config_path):
        if not args.config:
            raise ValueError(
                "Could not find params.pkl in either the checkpoint dir or "
                "its parent directory.")
    else:
        with open(config_path, "rb") as f:
            config = pickle.load(f)
    config = merge_dicts(config, args.config)
    ray.init()
    if algo == "RAINBOW":
        algo = "DQN"
    cls = get_agent_class(algo)
    agent = cls(env=ENV_NAME, config=config)
    agent.restore(args.checkpoint)
    policy_agent_mapping = default_policy_agent_mapping
    if hasattr(agent, "local_evaluator"):
        env = agent.local_evaluator.env
        multiagent = isinstance(env, MultiAgentEnv)
        if agent.local_evaluator.multiagent:
            policy_agent_mapping = agent.config["multiagent"][
                "policy_mapping_fn"]

        policy_map = agent.local_evaluator.policy_map
        state_init = {p: m.get_initial_state() for p, m in policy_map.items()}
        use_lstm = {p: len(s) > 0 for p, s in state_init.items()}
        action_init = {
            p: m.action_space.sample()
            for p, m in policy_map.items()
        }
    else:
        env = gym.make(ENV_NAME)
        multiagent = False
        use_lstm = {DEFAULT_POLICY_ID: False}

    steps = 0
    while steps < (len(env.df) or steps + 1):
        mapping_cache = {}  # in case policy_agent_mapping is stochastic
        obs = env.reset()
        agent_states = DefaultMapping(
            lambda agent_id: state_init[mapping_cache[agent_id]])
        prev_actions = DefaultMapping(
            lambda agent_id: action_init[mapping_cache[agent_id]])
        prev_rewards = collections.defaultdict(lambda: 0.)
        done = False
        reward_total = 0.0
        while not done and steps < (len(env.df) or steps + 1):
            multi_obs = obs if multiagent else {_DUMMY_AGENT_ID: obs}
            action_dict = {}
            for agent_id, a_obs in multi_obs.items():
                if a_obs is not None:
                    policy_id = mapping_cache.setdefault(
                        agent_id, policy_agent_mapping(agent_id))
                    p_use_lstm = use_lstm[policy_id]
                    if p_use_lstm:
                        a_action, p_state, _ = agent.compute_action(
                            a_obs,
                            state=agent_states[agent_id],
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                        agent_states[agent_id] = p_state
                    else:
                        a_action = agent.compute_action(
                            a_obs,
                            prev_action=prev_actions[agent_id],
                            prev_reward=prev_rewards[agent_id],
                            policy_id=policy_id)
                    action_dict[agent_id] = a_action
                    prev_actions[agent_id] = a_action
            action = action_dict

            action = action if multiagent else action[_DUMMY_AGENT_ID]
            next_obs, reward, done, _ = env.step(action)
            if multiagent:
                for agent_id, r in reward.items():
                    prev_rewards[agent_id] = r
            else:
                prev_rewards[_DUMMY_AGENT_ID] = reward

            if multiagent:
                done = done["__all__"]
                reward_total += sum(reward.values())
            else:
                reward_total += reward
            env.render()
            steps += 1
            obs = next_obs
        print("Episode reward", reward_total)