Example #1
def get_nn(self):
    config = get_PPO_config(1234)
    trainer = ppo.PPOTrainer(config=config)
    trainer.restore(self.nn_path)
    policy = trainer.get_policy()
    sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
    layers = []
    for l in sequential_nn:
        layers.append(l)
    nn = torch.nn.Sequential(*layers)
    return nn
Example #2
    def get_nn(self):
        config = get_PPO_config(1234)
        trainer = ppo.PPOTrainer(config=config)
        trainer.restore(self.nn_path)

        policy = trainer.get_policy()
        # sequential_nn = convert_ray_simple_policy_to_sequential(policy).cpu()
        sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
        # l0 = torch.nn.Linear(5, 3, bias=False)
        # l0.weight = torch.nn.Parameter(torch.tensor([[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]], dtype=torch.float32))
        layers = []
        for l in sequential_nn:
            layers.append(l)
        nn = torch.nn.Sequential(*layers)
        # ray.shutdown()
        return nn
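
The rebuild loop in these get_nn variants works because torch.nn.Sequential is iterable over its layers; below is a minimal self-contained sketch of that round trip (the three-layer network is a stand-in for the converted RLlib policy, not the real one):

import torch

# Stand-in for the output of convert_ray_policy_to_sequential(policy).cpu()
seq = torch.nn.Sequential(torch.nn.Linear(5, 16), torch.nn.ReLU(), torch.nn.Linear(16, 3))

# Iterating an nn.Sequential yields its layer modules, so this rebuilds an
# equivalent network that reuses the same layer objects.
layers = []
for layer in seq:
    layers.append(layer)
rebuilt = torch.nn.Sequential(*layers)

x = torch.randn(1, 5)
assert torch.equal(seq(x), rebuilt(x))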
Example #3
def get_nn(self):
    config = get_PPO_config(1234, use_gpu=0)
    trainer = ppo.PPOTrainer(config=config)
    trainer.restore(self.nn_path)
    policy = trainer.get_policy()
    sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
    # l0 = torch.nn.Linear(6, 2, bias=False)
    # l0.weight = torch.nn.Parameter(torch.tensor([[0, 0, 1, -1, 0, 0], [1, -1, 0, 0, 0, 0]], dtype=torch.float32))
    # layers = [l0]
    # for l in sequential_nn:
    #     layers.append(l)
    #
    # nn = torch.nn.Sequential(*layers)
    nn = sequential_nn
    # ray.shutdown()
    return nn
Example #4
def get_pendulum_ppo_agent():
    config = {
        "env": PendulumEnv,  #
        "model": {
            "fcnet_hiddens": [64, 64],
            "fcnet_activation": "relu"
        },  # model config
        "vf_share_layers": False,  # try different lrs
        "num_workers": 8,  # parallelism
        "num_envs_per_worker": 5,
        "train_batch_size": 2000,
        "framework": "torch",
        "horizon": 1000
    }  # "batch_mode":"complete_episodes"
    trainer = ppo.PPOTrainer(config=config)
    return trainer
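
For context, a minimal sketch of driving a trainer built this way, assuming an older RLlib release where ray.rllib.agents.ppo is available; CartPole-v0 is used here purely so the snippet is self-contained (the real example trains on PendulumEnv):

import ray
from ray.rllib.agents import ppo

ray.init(ignore_reinit_error=True)
trainer = ppo.PPOTrainer(config={
    "env": "CartPole-v0",   # built-in Gym env, no registration needed
    "framework": "torch",
    "num_workers": 1,
    "train_batch_size": 2000,
})
for i in range(3):
    result = trainer.train()
    print(i, result["episode_reward_mean"])
checkpoint_path = trainer.save()  # can be loaded later with trainer.restore(checkpoint_path)
ray.shutdown()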
Example #5
def get_nn(self):
    pickled_path = self.nn_path + ".pickle"
    if os.path.exists(pickled_path):
        nn = torch.load(pickled_path, map_location=torch.device('cpu'))
        return nn
    config = get_PPO_config(1234, use_gpu=0)
    trainer = ppo.PPOTrainer(config=config)
    trainer.restore(self.nn_path)
    policy = trainer.get_policy()
    sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
    layers = []
    for l in sequential_nn:
        layers.append(l)
    nn = torch.nn.Sequential(*layers)
    torch.save(nn, pickled_path)
    return nn
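
The caching pattern above (serialize the converted network next to the checkpoint, reload it on later runs) can be exercised in isolation; the small Sequential below is only a placeholder for the converted policy:

import os
import torch

def load_or_build(pickled_path):
    # Reuse the cached network if it exists; otherwise build it and cache it.
    if os.path.exists(pickled_path):
        return torch.load(pickled_path, map_location=torch.device('cpu'))
    nn = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
    torch.save(nn, pickled_path)
    return nn

net = load_or_build("policy_cache.pickle")
print(net(torch.zeros(1, 4)))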
Example #6
def get_car_ppo_agent():
    config = {
        "env": StoppingCar,  #
        "model": {
            "fcnet_hiddens": [20, 20, 20, 20],
            "fcnet_activation": "relu"
        },  # model config,"custom_model": "my_model",
        "vf_share_layers": False,  # try different lrs
        "num_workers": 8,  # parallelism
        # "batch_mode": "complete_episodes", "use_gae": False,  #
        "num_envs_per_worker": 5,
        "train_batch_size": 2000,
        "framework": "torch",
        "horizon": 1000
    }
    trainer = ppo.PPOTrainer(config=config)
    return trainer
Example #7
def go_train(config):
    trainer = ppo.PPOTrainer(config=config, env="continuousDoubleAuction-v0")

    if is_restore:
        trainer.restore(restore_path)

    g_store = ray.util.get_actor("g_store")
    result = None
    for i in range(num_iters):
        result = trainer.train()
        print(pretty_print(result))  # includes result["custom_metrics"]
        print("training loop = {} of {}".format(i + 1, num_iters))
        print("eps sampled so far {}".format(
            ray.get(g_store.get_eps_counter.remote())))

    print("result['experiment_id']", result["experiment_id"])

    return result
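
g_store is looked up by name but never defined in this snippet; the following is only a plausible sketch of such a named counter actor (class and method names are assumptions inferred from the call above):

import ray

ray.init()

@ray.remote
class GlobalStore:
    def __init__(self):
        self._eps = 0

    def incr_eps_counter(self, n=1):
        self._eps += n

    def get_eps_counter(self):
        return self._eps

# Named actor, so other code in the job can look it up (e.g. via ray.util.get_actor).
g_store = GlobalStore.options(name="g_store").remote()
g_store.incr_eps_counter.remote(3)
print(ray.get(g_store.get_eps_counter.remote()))  # -> 3
ray.shutdown()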
Example #8
def train_ppo(n_iterations):

    policy_map = policy_mapping_global.copy()
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
        "num_workers": 16,
        "num_gpus": 1,
        "vf_share_layers": True,
        "vf_loss_coeff": 20.00,
        "vf_clip_param": 200.0,
        "lr": 2e-4,
        "multiagent": {
            "policies": filter_keys(policies,
                                    set(policy_mapping_global.values())),
            "policy_mapping_fn": create_policy_mapping_fn(policy_map),
            "policies_to_train": ['ppo_producer', 'ppo_consumer']
        }
    })

    print(
        f"Environment: action space producer {env.action_space_producer}, action space consumer {env.action_space_consumer}, observation space {env.observation_space}"
    )

    ppo_trainer = ppo.PPOTrainer(env=wsr.WorldOfSupplyEnv,
                                 config=dict(ext_conf, **base_trainer_config))

    training_start_time = time.process_time()
    for i in range(n_iterations):
        print(f"\n== Iteration {i} ==")
        update_policy_map(policy_map, i, n_iterations)
        print(f"- policy map: {policy_map}")

        ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
            lambda env: env.set_iteration(i, n_iterations)))

        t = time.process_time()
        result = ppo_trainer.train()
        print(f"Iteration {i} took [{(time.process_time() - t):.2f}] seconds")
        print_training_results(result)
        print(
            f"Training ETA: [{(time.process_time() - training_start_time)*(n_iterations/(i+1)-1)/60/60:.2f}] hours to go"
        )

    return ppo_trainer
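
filter_keys and create_policy_mapping_fn are project helpers not shown here; hypothetical implementations consistent with how they are used above might look like this:

def filter_keys(d, keys):
    # Keep only the policy specs that the mapping actually references.
    return {k: v for k, v in d.items() if k in keys}

def create_policy_mapping_fn(policy_map):
    # policy_map: {agent_id: policy_id}. RLlib calls the returned function with
    # an agent id and expects the id of the policy that agent should use.
    def policy_mapping_fn(agent_id):
        return policy_map[agent_id]
    return policy_mapping_fn

mapping = create_policy_mapping_fn({"producer_0": "ppo_producer", "consumer_0": "ppo_consumer"})
print(mapping("consumer_0"))  # -> ppo_consumer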
Example #9
def get_nn(self):
    pickled_path = self.nn_path + ".pickle"
    if os.path.exists(pickled_path):
        nn = torch.load(pickled_path, map_location=torch.device('cpu'))
        return nn
    config = get_PPO_config(1234, 0)
    trainer = ppo.PPOTrainer(config=config)
    trainer.restore(self.nn_path)
    policy = trainer.get_policy()
    sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
    # l0 = torch.nn.Linear(6, 2, bias=False)
    # l0.weight = torch.nn.Parameter(torch.tensor([[0, 0, 1, -1, 0, 0], [1, -1, 0, 0, 0, 0]], dtype=torch.float32))
    # layers = [l0]
    # for l in sequential_nn:
    #     layers.append(l)
    #
    # nn = torch.nn.Sequential(*layers)
    nn = sequential_nn
    torch.save(nn, pickled_path)
    # ray.shutdown()
    return nn
Example #10
import numpy as np
import ray
import torch.nn
from ray.rllib.agents.ppo import ppo

from environment.collision_avoidance import ColAvoidEnvDiscrete
from training.ppo.tune.tune_train_PPO_collision_avoidance import get_PPO_config
from training.ray_utils import convert_ray_policy_to_sequential

ray.init()
# register_env("fishing", env_creator)
config = get_PPO_config(1234)
trainer = ppo.PPOTrainer(config=config)
# trainer.restore("/home/edoardo/ray_results/tune_PPO_lunar_hover/PPO_LunarHover_7ba4e_00000_0_2021-04-02_19-01-43/checkpoint_990/checkpoint-990")
trainer.restore("/home/edoardo/ray_results/tune_PPO_collision_avoidance/PPO_ColAvoidEnvDiscrete_12944_00000_0_2021-04-26_15-24-12/checkpoint_160/checkpoint-160")

policy = trainer.get_policy()
# sequential_nn = convert_ray_simple_policy_to_sequential(policy).cpu()
sequential_nn = convert_ray_policy_to_sequential(policy).cpu()
# l0 = torch.nn.Linear(4, 2, bias=False)
# l0.weight = torch.nn.Parameter(torch.tensor([[0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float32))
# layers = [l0]
# for l in sequential_nn:
#     layers.append(l)
# nn = torch.nn.Sequential(*layers)
nn = sequential_nn
env = ColAvoidEnvDiscrete()
# env.render()
plot_index = 0
position_list = []
# env.render()
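
A minimal sketch of rolling out a converted sequential policy greedily in a Gym environment; the network and CartPole-v0 below are stand-ins (the script above uses ColAvoidEnvDiscrete and the converted RLlib policy), and the old reset()/step() Gym API is assumed:

import gym
import torch

# Stand-in policy network: 4-d observation -> 2 action logits.
nn = torch.nn.Sequential(torch.nn.Linear(4, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

env = gym.make("CartPole-v0")
obs = env.reset()
for _ in range(100):
    with torch.no_grad():
        logits = nn(torch.tensor(obs, dtype=torch.float32))
    action = int(torch.argmax(logits).item())  # greedy action
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()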
Example #11
import numpy as np
import ray
from ray.rllib.agents.ppo import ppo
from ray.rllib.models import ModelCatalog
from tqdm import tqdm

from aie import plotting
from aie.aie_env import AIEEnv
from rl.conf import BASE_PPO_CONF, OUT_DIR
from rl.models.tf.fcnet_lstm import RNNModel

# %%
ray.init()
ModelCatalog.register_custom_model("my_model", RNNModel)

# %%
trainer = ppo.PPOTrainer(config={
    **BASE_PPO_CONF,
    "num_gpus": 1,
    "num_workers": 0,
})

ckpt_path = OUT_DIR / 'PPO_AIEEnv_2021-02-20_21-39-20554z54an/checkpoint_14/checkpoint-14'

trainer.restore(str(ckpt_path))

# %%
env = AIEEnv({}, force_dense_logging=True)
obs = env.reset()
hidden_states = {
    k: [
        np.zeros(128, np.float32),
        np.zeros(128, np.float32),
    ]
Example #12
        'eager': True
    }
    config.update(ppo_conf)
    config['log_level'] = 'INFO'

    config["num_gpus"] = args.ngpu
    config["num_workers"] = args.ncpu
    config['num_envs_per_worker'] = 1
    config['model'] = {
        "custom_model": "rnn",
        "max_seq_len": 16,
    }
    config['env_config'] = envconf
    config['eager'] = True

    trainer = ppo.PPOTrainer(config=config, env='lactamase_docking')
    policy = trainer.get_policy()
    print(policy.model.base_model.summary())

    config['env'] = 'lactamase_docking'

    for i in range(250):
        result = trainer.train()

        if i % 1 == 0:
            print(pretty_print(result))

        if i % 25 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
Example #13
def train_ppo(args):
    env_config_for_rendering.update({'init': args.init})
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
        "env": InventoryManageEnv,
        "framework": "torch",
        "num_workers": 4,
        "vf_share_layers": True,
        "vf_loss_coeff": 1.00,
        # estimated max value of vf, used to normalization
        "vf_clip_param": 100.0,
        "clip_param": 0.2,
        "use_critic": True,
        "use_gae": True,
        "lambda": 1.0,
        "gamma": 0.99,
        'env_config': env_config_for_rendering.copy(),
        # Number of steps after which the episode is forced to terminate. Defaults
        # to `env.spec.max_episode_steps` (if present) for Gym envs.
        "horizon": args.episod,
        # Calculate rewards but don't reset the environment when the horizon is
        # hit. This allows value estimation and RNN state to span across logical
        # episodes denoted by horizon. This only has an effect if horizon != inf.
        "soft_horizon": False,
        # Minimum env steps to optimize for per train call. This value does
        # not affect learning, only the length of train iterations.
        'timesteps_per_iteration': 1000,
        'batch_mode': 'complete_episodes',
        # Size of batches collected from each worker
        "rollout_fragment_length": args.rollout_fragment_length,
        # Number of timesteps collected for each SGD round. This defines the size
        # of each SGD epoch.
        "train_batch_size": args.rollout_fragment_length * args.batch_size,
        # Whether to shuffle sequences in the batch when training (recommended).
        "shuffle_sequences": True,
        # Total SGD batch size across all devices for SGD. This defines the
        # minibatch size within each epoch.
        "sgd_minibatch_size":
        args.rollout_fragment_length * args.min_batch_size,
        # Number of SGD iterations in each outer loop (i.e., number of epochs to
        # execute per train batch).
        "num_sgd_iter": 50,
        "lr": 1e-4,
        "_fake_gpus": True,
        "num_gpus": 0,
        "explore": True,
        "exploration_config": {
            "type": StochasticSampling,
            "random_timesteps":
            0,  # args.rollout_fragment_length*args.batch_size*args.stop_iters // 2,
        },
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_map_fn,
            "policies_to_train": ['ppo_store_consumer']
        }
    })

    print(
        f"Environment: action space producer {env.action_space_producer}, action space consumer {env.action_space_consumer}, observation space {env.observation_space}",
        flush=True)

    if args.is_pretrained:
        ext_conf.update({
            'num_workers': 0  #, 'episod_duration':args.episod
        })
        ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
        env.env_config.update({'episod_duration': args.episod})
        ppo_trainer.restore(args.premodel)
        visualization(InventoryManageEnv(env_config.copy()),
                      get_policy(env, ppo_trainer), 1, args.run_name)
        return ppo_trainer

    # ppo_trainer.restore('/root/ray_results/PPO_InventoryManageEnv_2020-11-02_18-25-55cle_glgg/checkpoint_20/checkpoint-20')

    # stop = {
    #     "training_iteration": args.stop_iters,
    #     "timesteps_total": args.stop_timesteps,
    #     "episode_reward_min": args.stop_reward,
    # }

    # analysis = tune.run(args.run, config=ext_conf, stop=stop, mode='max', checkpoint_freq=1, verbose=1)
    # checkpoints = analysis.get_trial_checkpoints_paths(
    #                         trial=analysis.get_best_trial("episode_reward_max"),
    #                         metric="episode_reward_max")
    # ppo_trainer.restore(checkpoints[0][0])

    ext_conf['env_config'].update({
        'gamma': ext_conf['gamma'],
        'training': True,
        'policies': None
    })

    ppo_trainer = ppo.PPOTrainer(env=InventoryManageEnv, config=ext_conf)
    max_mean_reward = -100

    ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
        lambda env: env.set_policies(get_policy(env, ev))))

    for i in range(args.stop_iters):
        print("== Iteration", i, "==", flush=True)

        ppo_trainer.workers.foreach_worker(lambda ev: ev.foreach_env(
            lambda env: env.set_iteration(i, args.stop_iters)))
        result = ppo_trainer.train()
        print_training_results(result)
        now_mean_reward = result['policy_reward_mean']['ppo_store_consumer']

        if (i + 1) % args.visualization_frequency == 0 or now_mean_reward > max_mean_reward:
            max_mean_reward = max(max_mean_reward, now_mean_reward)
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint, flush=True)
            visualization(InventoryManageEnv(env_config.copy()),
                          get_policy(env, ppo_trainer), i, args.run_name)
            # exit(0)

    return ppo_trainer
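
The workers.foreach_worker(... foreach_env(...)) calls above broadcast a callable to every rollout worker and then to every environment copy held by that worker; this is how per-iteration state (here the iteration index) is pushed into the custom envs. A minimal sketch against a built-in env, again assuming the older ray.rllib.agents API:

import ray
from ray.rllib.agents import ppo

ray.init(ignore_reinit_error=True)
trainer = ppo.PPOTrainer(config={"env": "CartPole-v0", "framework": "torch", "num_workers": 1})

# Returns one list per worker, each with one entry per env copy on that worker.
env_names = trainer.workers.foreach_worker(
    lambda worker: worker.foreach_env(lambda env: type(env).__name__))
print(env_names)
ray.shutdown()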
Example #14
def _main():
    """ Training example """

    # Initialize RAY.
    ray.tune.registry.register_env('test_env', marlenvironment.env_creator)
    ray.init()

    # Algorithm.
    policy_class = ppo.PPOTFPolicy
    # https://github.com/ray-project/ray/blob/releases/0.8.3/rllib/agents/trainer.py#L41
    # https://github.com/ray-project/ray/blob/releases/0.8.3/rllib/agents/ppo/ppo.py#L15
    policy_conf = ppo.DEFAULT_CONFIG
    policy_conf['batch_mode'] = 'complete_episodes'
    policy_conf['log_level'] = 'WARN'
    policy_conf['min_iter_time_s'] = 5
    policy_conf['num_workers'] = 2
    policy_conf['rollout_fragment_length'] = 1
    policy_conf['seed'] = 42
    policy_conf['sgd_minibatch_size'] = 1
    policy_conf['simple_optimizer'] = True
    policy_conf['train_batch_size'] = 1

    # Load default Scenario configuration for the LEARNING ENVIRONMENT
    scenario_config = deepcopy(marlenvironment.DEFAULT_SCENARIO_CONFING)
    scenario_config['seed'] = 42
    scenario_config['log_level'] = 'INFO'
    scenario_config['sumo_config']['sumo_connector'] = 'libsumo'
    scenario_config['sumo_config'][
        'sumo_cfg'] = '{}/scenario/sumo.cfg.xml'.format(
            pathlib.Path(__file__).parent.absolute())
    scenario_config['sumo_config']['sumo_params'] = [
        '--collision.action', 'warn'
    ]
    scenario_config['sumo_config']['trace_file'] = True
    scenario_config['sumo_config']['end_of_sim'] = 3600  # [s]
    # Number of traci.simulationStep() calls for each learning step.
    scenario_config['sumo_config']['update_freq'] = 10
    scenario_config['sumo_config']['log_level'] = 'INFO'
    logger.info('Scenario Configuration: \n %s', pformat(scenario_config))

    # Associate the agents with their configuration.
    agent_init = {
        'agent_0': deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
        'agent_1': deepcopy(marlenvironment.DEFAULT_AGENT_CONFING),
    }
    logger.info('Agents Configuration: \n %s', pformat(agent_init))

    ## MARL Environment Init
    env_config = {
        'agent_init': agent_init,
        'scenario_config': scenario_config,
    }
    marl_env = marlenvironment.SUMOTestMultiAgentEnv(env_config)

    # Config for the PPO trainer from the MARLEnv
    policies = {}
    for agent in marl_env.get_agents():
        agent_policy_params = {}
        policies[agent] = (policy_class, marl_env.get_obs_space(agent),
                           marl_env.get_action_space(agent),
                           agent_policy_params)
    policy_conf['multiagent']['policies'] = policies
    policy_conf['multiagent']['policy_mapping_fn'] = lambda agent_id: agent_id
    policy_conf['multiagent']['policies_to_train'] = ['ppo_policy']
    policy_conf['env_config'] = env_config

    logger.info('PPO Configuration: \n %s', pformat(policy_conf))
    trainer = ppo.PPOTrainer(env='test_env', config=policy_conf)

    # Single training iteration, just for testing.
    try:
        result = trainer.train()
        print('Results: \n {}'.format(pretty_print(result)))
    except Exception:
        EXC_TYPE, EXC_VALUE, EXC_TRACEBACK = sys.exc_info()
        traceback.print_exception(EXC_TYPE,
                                  EXC_VALUE,
                                  EXC_TRACEBACK,
                                  file=sys.stdout)
    finally:
        ray.shutdown()
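
The policies dict assembled above follows RLlib's old multiagent spec, {policy_id: (policy_cls, obs_space, act_space, policy_config)}. A self-contained sketch with made-up spaces (passing None as the policy class tells RLlib to use the trainer's default policy):

import numpy as np
from gym.spaces import Box, Discrete

obs_space = Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
act_space = Discrete(3)

policies = {
    "agent_0": (None, obs_space, act_space, {}),
    "agent_1": (None, obs_space, act_space, {}),
}
multiagent_conf = {
    "policies": policies,
    "policy_mapping_fn": lambda agent_id: agent_id,  # one policy per agent id
}
print(multiagent_conf["policy_mapping_fn"]("agent_1"))  # -> agent_1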
Example #15
def train_ppo(args, env, knapsack_config, workdir, n_iterations):
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
            "num_workers": 2,
            "num_cpus_per_worker": 1,
            "vf_share_layers": True,
            "vf_loss_coeff": 1.0,      
            "vf_clip_param": 100.0,
            "use_critic": True,
            "use_gae": True,
            "framework": "torch",
            "lambda": 1.0,
            "gamma": 1.0,
            'env_config': knapsack_config,
            'timesteps_per_iteration': knapsack_config['episode_len'],
            'batch_mode': 'complete_episodes',
            # Size of batches collected from each worker
            "rollout_fragment_length": args.rollout,
            # Number of timesteps collected for each SGD round. This defines the size
            # of each SGD epoch.
            "train_batch_size": args.batch_size*args.rollout,
            # Total SGD batch size across all devices for SGD. This defines the
            # minibatch size within each epoch.
            "sgd_minibatch_size": args.min_batch_size*args.rollout,
            # Number of SGD iterations in each outer loop (i.e., number of epochs to
            # execute per train batch).
            "num_sgd_iter": 100,
            "shuffle_sequences": True,
            "lr": 1e-4,
            "_fake_gpus": True,
            "num_gpus": 0,
            "num_gpus_per_worker": 0,
            "model": {"custom_model": "knapsack_model"},
            "explore": True,
            # "exploration_config": {
            #     # The Exploration class to use.
            #     "type": "EpsilonGreedy",
            #     # Config for the Exploration class' constructor:
            #     "initial_epsilon": 1.0,
            #     "final_epsilon": 0.02,
            #     "epsilon_timesteps": args.rollout*args.batch_size*args.iters // 3,  # Timesteps over which to anneal epsilon.
            # },
            "exploration_config": {
                "type": StochasticSampling,
                "random_timesteps": args.rollout*args.batch_size*args.iters // 4,
            },
        })
    
    print(f"Environment: action space {env.action_space}, observation space {env.observation_space}")
    ppo_trainer = ppo.PPOTrainer(
        env = KnapsackEnv,
        config = ext_conf)
    
    # ppo_trainer.restore('/root/ray_results/PPO_CVRPEnv_2020-12-29_11-50-29uylrljyr/checkpoint_100/checkpoint-100')
    
    mean_cost_list = []
    total_cost_list = []
    for i in range(n_iterations):
        print("== Iteration", i, "==")
        trainer_result = ppo_trainer.train()
        print_training_results(trainer_result)
        # cost = env.total_cost - (trainer_result['episode_reward_mean']*env.total_cost) / trainer_result['episode_len_mean']
        # cost = (1.0 - trainer_result['episode_reward_mean']/trainer_result['episode_len_mean']) * env.max_cost * env.num_nodes
        cost = trainer_result['episode_reward_mean']
        mean_cost_list.append(cost)
        print('cost: ', cost)
        if (i+1) % 5 == 0:
            checkpoint = ppo_trainer.save()
            print("checkpoint saved at", checkpoint)
            _total_value = draw_route(args, ppo_trainer, env, mean_cost_list, workdir)
            total_cost_list.append(_total_value)
    list_to_figure([total_cost_list], ['total_cost'], 'total_cost', f'{workdir}/rl_knapsack_total_cost_{args.problem}.png')
    return ppo_trainer, mean_cost_list
Example #16
def create_ppo_trainer(echelon):
    policy_map_fn = (lambda x: echelon_policy_map_fn(echelon, x))
    policies_to_train = (['ppo_store_consumer'] if echelon == env.world.total_echelon - 1 else ['ppo_warehouse_consumer'])
    ext_conf = ppo.DEFAULT_CONFIG.copy()
    ext_conf.update({
            "env": InventoryManageEnv,
            "framework": "torch",
            "num_workers": 2,
            "vf_share_layers": True,
            "vf_loss_coeff": 1.00,   
            # estimated max value of vf, used to normalization   
            "vf_clip_param": 10.0,
            "clip_param": 0.1, 
            "use_critic": True,
            "use_gae": True,
            "lambda": 1.0,
            "gamma": 0.9,
            'env_config': env_config_for_rendering,
            # Number of steps after which the episode is forced to terminate. Defaults
            # to `env.spec.max_episode_steps` (if present) for Gym envs.
            "horizon": args.episod,
            # Calculate rewards but don't reset the environment when the horizon is
            # hit. This allows value estimation and RNN state to span across logical
            # episodes denoted by horizon. This only has an effect if horizon != inf.
            "soft_horizon": False,
            # Minimum env steps to optimize for per train call. This value does
            # not affect learning, only the length of train iterations.
            'timesteps_per_iteration': 1000,
            'batch_mode': 'complete_episodes',
            # Size of batches collected from each worker
            "rollout_fragment_length": args.rollout_fragment_length,
            # Number of timesteps collected for each SGD round. This defines the size
            # of each SGD epoch.
            "train_batch_size": args.rollout_fragment_length*args.batch_size,
            # Whether to shuffle sequences in the batch when training (recommended).
            "shuffle_sequences": True,
            # Total SGD batch size across all devices for SGD. This defines the
            # minibatch size within each epoch.
            "sgd_minibatch_size": args.rollout_fragment_length*args.min_batch_size,
            # Number of SGD iterations in each outer loop (i.e., number of epochs to
            # execute per train batch).
            "num_sgd_iter": 50,
            "lr": 1e-4,
            "_fake_gpus": True,
            "num_gpus": 0,
            "explore": True,
            "exploration_config": {
                "type": StochasticSampling,
                "random_timesteps": 10000, # args.rollout_fragment_length*args.batch_size*args.stop_iters // 2,
            },
            "multiagent": {
                "policies": filter_keys(policies, ['baseline_producer', 'baseline_consumer', 'ppo_store_consumer', 'ppo_warehouse_consumer']),
                "policy_mapping_fn": policy_map_fn,
                "policies_to_train": policies_to_train
            }
        })

    print(f"Environment: action space producer {env.action_space_producer}, action space consumer {env.action_space_consumer}, observation space {env.observation_space}")
    ppo_trainer = ppo.PPOTrainer(
        env = InventoryManageEnv,
        config = ext_conf)
    return ppo_trainer
Example #17
def _main():
    """ Testing loop """
    # Args
    logger.info('Arguments: %s', str(ARGS))

    # Results
    metrics_dir, checkpoint_dir, debug_dir = results_handler(ARGS)

    # Algorithm.
    policy_class = None
    policy_conf = None
    policy_params = None
    checkout_steps = 1  # save each episode
    if ARGS.algo == 'PPO':
        policy_class = ppo.PPOTFPolicy
        policy_conf = {
            **ppo.DEFAULT_CONFIG,
            **ppo_conf.ppo_conf(checkout_steps, debug_dir)
        }
        policy_params = {}
    elif ARGS.algo == 'A3C':
        policy_class = a3c.A3CTFPolicy
        policy_conf = {
            **a3c.DEFAULT_CONFIG,
            **a3c_conf.a3c_conf(checkout_steps, debug_dir)
        }
        policy_params = {}
    elif ARGS.algo == 'QLSA':
        policy_class = QLStandAlone.QLearningTestingPolicy
        policy_conf = qlearning_conf.qlearning_conf(checkout_steps, debug_dir)
        policy_params = {}
    else:
        raise Exception('Unknown algorithm %s' % ARGS.algo)

    # Load default Scenario configuration
    experiment_config = load_json_file(ARGS.config)

    # Initialize the simulation.
    ray.init(memory=52428800, object_store_memory=78643200)  ## minimum values

    # Associate the agents with something
    agent_init = load_json_file(experiment_config['agents_init_file'])
    env_config = {
        'metrics_dir': metrics_dir,
        'checkpoint_dir': checkpoint_dir,
        'agent_init': agent_init,
        'scenario_config': experiment_config['marl_env_config'],
    }
    marl_env = None
    if ARGS.env == 'MARL':
        ray.tune.registry.register_env('marl_env', marlenvironment.env_creator)
        marl_env = marlenvironment.PersuasiveMultiAgentEnv(env_config)
    elif ARGS.env == 'MARLCoop':
        ray.tune.registry.register_env('marl_env',
                                       marlenvironmentagentscoop.env_creator)
        marl_env = marlenvironmentagentscoop.AgentsCoopMultiAgentEnv(
            env_config)
    elif ARGS.env == 'LateMARL':
        ray.tune.registry.register_env('marl_env',
                                       marlenvironmentlatereward.env_creator)
        marl_env = marlenvironmentlatereward.LateRewardMultiAgentEnv(
            env_config)
    else:
        raise Exception('Unknown environment %s' % ARGS.env)

    # Gen config
    policies = {}
    for agent in marl_env.get_agents():
        agent_policy_params = deepcopy(policy_params)
        from_val, to_val = agent_init[agent]['init']
        # Bind the current agent's range as default args so each lambda keeps
        # its own (from_val, to_val) instead of the loop's final values.
        agent_policy_params['init'] = lambda from_val=from_val, to_val=to_val: random.randint(from_val, to_val)
        agent_policy_params['actions'] = marl_env.get_set_of_actions(agent)
        agent_policy_params['seed'] = agent_init[agent]['seed']
        policies[agent] = (policy_class, marl_env.get_obs_space(agent),
                           marl_env.get_action_space(agent),
                           agent_policy_params)
    policy_conf['multiagent'] = {
        'policies': policies,
        'policy_mapping_fn': lambda agent_id: agent_id,
    }
    policy_conf['env_config'] = env_config
    logger.info('Configuration: \n%s', pformat(policy_conf))

    def default_logger_creator(config):
        """
            Creates a Unified logger with a default logdir prefix
            containing the agent name and the env id
        """
        log_dir = os.path.join(os.path.normpath(ARGS.dir), 'logs')
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        return UnifiedLogger(config,
                             log_dir)  # loggers = None) >> Default loggers

    def dblogger_logger_creator(config):
        """
            Creates a Unified logger with a default logdir prefix
            containing the agent name and the env id
        """
        log_dir = os.path.join(os.path.normpath(ARGS.dir), 'logs')
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        return UnifiedLogger(config, log_dir, loggers=[DBLogger])

    trainer = None
    if ARGS.algo == 'PPO':
        trainer = ppo.PPOTrainer(env='marl_env',
                                 config=policy_conf,
                                 logger_creator=default_logger_creator)
    elif ARGS.algo == 'A3C':
        trainer = a3c.A3CTrainer(env='marl_env',
                                 config=policy_conf,
                                 logger_creator=default_logger_creator)
    elif ARGS.algo == 'QLSA':
        trainer = QLStandAlone.QLearningTester(
            env='marl_env',
            config=policy_conf,
            logger_creator=dblogger_logger_creator)
    else:
        raise Exception('Unknown algorithm %s' % ARGS.algo)

    target_checkpoint = get_target_checkpoint(ARGS.target_checkpoint)
    if target_checkpoint is not None:
        logger.info('[Trainer:main] Restoring checkpoint: %s',
                    target_checkpoint)
        trainer.restore(target_checkpoint)
    else:
        raise Exception('Checkpoint {} does not exist.'.format(
            ARGS.target_checkpoint))

    steps = 0
    final_result = None

    for _ in range(ARGS.testing_episodes):
        # Do one step.
        result = trainer.train()
        checkpoint = trainer.save(checkpoint_dir)
        logger.info('[Trainer:main] Checkpoint saved in %s', checkpoint)
        # steps += result['info']['num_steps_trained']
        # 'timesteps_this_iter' is related to 'timesteps_total', which is the
        # same as result['info']['num_steps_sampled'].
        steps += result['timesteps_this_iter']
        final_result = result

    print_selected_results(final_result, SELECTION)