Esempio n. 1
0
from matplotlib import pyplot as plt
import numpy as np
import torch
import argparse
# Command-line parameters selecting which logged run / checkpoint to load.
parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', type=str, default='zero_sum')
parser.add_argument('--log_dir', type=str, default='PRGGaussiank1oace')
# Checkpoint iteration to load; when omitted, <seed dir>/params.pkl is used.
parser.add_argument('--epoch', type=int, default=None)
parser.add_argument('--seed', type=int, default=0)
args = parser.parse_args()

pre_path = './Data/' + args.exp_name + '/' + args.log_dir

from differential_game import DifferentialGame
env = DifferentialGame(game_name=args.exp_name)

# 100-point grids over [-1, 1] for the two action axes. c1s / c2s start
# empty; presumably they are filled further down the script (not visible
# here -- confirm).
a1s = np.linspace(-1, 1, 100)
a2s = np.linspace(-1, 1, 100)
c1s = []
c2s = []

# Every artifact of this run lives under the per-seed directory.
seed_dir = f'{pre_path}/seed{args.seed}'
# `--epoch 0` is a valid checkpoint, so compare against None instead of
# relying on truthiness.
if args.epoch is not None:
    d_path = f'{seed_dir}/itr_{args.epoch}.pkl'
    plot_file = f'{seed_dir}/cactor_ep{args.epoch}'
else:
    d_path = f'{seed_dir}/params.pkl'
    plot_file = f'{seed_dir}/cactor'
# NOTE: torch.load unpickles the file and can execute arbitrary code; only
# load checkpoints you produced or otherwise trust.
data = torch.load(d_path, map_location='cpu')
Esempio n. 2
0
def experiment(variant):
    """Build and train the PRG multi-agent setup on a differential game.

    Per agent this creates two centralized Q-networks, a conditional actor
    ("cactor") whose input covers every observation plus all but one agent's
    actions, a decentralized TanhMlpPolicy, target copies of the critics and
    the policy, and an OU-noise exploration wrapper. The per-agent lists are
    handed to ``PRGTrainer`` and trained with ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Keys read here: 'num_agent',
            'qf_kwargs', 'cactor_kwargs', 'policy_kwargs',
            'replay_buffer_size', 'trainer_kwargs', 'algorithm_kwargs'.
    """
    n_agents = variant['num_agent']
    from differential_game import DifferentialGame
    expl_env = DifferentialGame(game_name=args.exp_name)
    eval_env = DifferentialGame(game_name=args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    act_dim = eval_env.action_space.low.size

    # Critic input: every agent's observation and every agent's action.
    critic_in = obs_dim * n_agents + act_dim * n_agents
    # Conditional-actor input: every observation, but one action fewer.
    cactor_in = obs_dim * n_agents + act_dim * (n_agents - 1)

    qf1_n, qf2_n, target_qf1_n, target_qf2_n = [], [], [], []
    cactor_n, policy_n, target_policy_n = [], [], []
    expl_policy_n, eval_policy_n = [], []
    for _ in range(n_agents):
        # Build order (critic 1, critic 2, cactor, policy) is kept fixed so
        # any RNG-driven weight initialisation happens in the same sequence.
        qf1 = FlattenMlp(input_size=critic_in, output_size=1,
                         **variant['qf_kwargs'])
        qf2 = FlattenMlp(input_size=critic_in, output_size=1,
                         **variant['qf_kwargs'])
        cactor = TanhMlpPolicy(input_size=cactor_in, output_size=act_dim,
                               **variant['cactor_kwargs'])
        policy = TanhMlpPolicy(input_size=obs_dim, output_size=act_dim,
                               **variant['policy_kwargs'])

        qf1_n.append(qf1)
        target_qf1_n.append(copy.deepcopy(qf1))
        qf2_n.append(qf2)
        target_qf2_n.append(copy.deepcopy(qf2))
        cactor_n.append(cactor)
        policy_n.append(policy)
        target_policy_n.append(copy.deepcopy(policy))

        # Exploration adds OU noise around the policy; evaluation uses the
        # very same policy object without any wrapper.
        ou_noise = OUStrategy(action_space=expl_env.action_space)
        expl_policy_n.append(
            PolicyWrappedWithExplorationStrategy(
                exploration_strategy=ou_noise,
                policy=policy,
            ))
        eval_policy_n.append(policy)

    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, expl_policy_n)
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=n_agents)
    trainer = PRGTrainer(env=expl_env,
                         qf1_n=qf1_n,
                         target_qf1_n=target_qf1_n,
                         qf2_n=qf2_n,
                         target_qf2_n=target_qf2_n,
                         policy_n=policy_n,
                         target_policy_n=target_policy_n,
                         cactor_n=cactor_n,
                         **variant['trainer_kwargs'])
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Esempio n. 3
0
def experiment(variant):
    """Build and train the MASAC baseline on a differential game.

    Each agent gets a MixTanhGaussianPolicy (a two-layer ReLU trunk feeding a
    mixture-weight head plus mean and log-std heads reshaped per mixture
    component), a MakeDeterministic evaluation wrapper, an exploration policy
    and twin centralized Q-networks with target copies. These are passed to
    ``MASACTrainer`` and trained with ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Keys read here: 'num_agent',
            'policy_kwargs' ('hidden_dim', 'm' -- presumably the number of
            mixture components), 'qf_kwargs' ('hidden_dim'),
            'random_exploration', 'replay_buffer_size', 'trainer_kwargs',
            'algorithm_kwargs'.
    """
    num_agent = variant['num_agent']
    from differential_game import DifferentialGame
    expl_env = DifferentialGame(game_name=args.exp_name)
    eval_env = DifferentialGame(game_name=args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    from rlkit.torch.layers import SplitLayer, ReshapeLayer
    from rlkit.torch.policies.mix_tanh_gaussian_policy import MixTanhGaussianPolicy
    from rlkit.torch.policies.make_deterministic import MakeDeterministic
    from rlkit.exploration_strategies.base import PolicyWrappedWithExplorationStrategy
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.torch.networks import FlattenMlp

    hidden = variant['policy_kwargs']['hidden_dim']
    n_modes = variant['policy_kwargs']['m']
    qf_hidden = variant['qf_kwargs']['hidden_dim']
    critic_in = obs_dim * num_agent + action_dim * num_agent

    def per_mode_head():
        # Linear output reshaped to one action-sized vector per mode.
        return nn.Sequential(
            nn.Linear(hidden, action_dim * n_modes),
            ReshapeLayer(shape=[n_modes, action_dim]))

    def critic():
        # Centralized Q-network over all observations and all actions.
        return FlattenMlp(
            input_size=critic_in,
            output_size=1,
            hidden_sizes=[qf_hidden] * 2,
        )

    policy_n, eval_policy_n, expl_policy_n = [], [], []
    qf1_n, target_qf1_n, qf2_n, target_qf2_n = [], [], [], []
    for _ in range(num_agent):
        # Heads are built first, then the trunk layers; the sequence is kept
        # fixed so weight initialisation draws happen in the same order.
        weight_head = nn.Linear(hidden, n_modes)
        mean_head = per_mode_head()
        logstd_head = per_mode_head()
        trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            SplitLayer(layers=[weight_head, mean_head, logstd_head]))
        policy = MixTanhGaussianPolicy(module=trunk)
        eval_policy = MakeDeterministic(policy)

        if not variant['random_exploration']:
            expl_policy = policy
        else:
            # prob_random_action=1.0: presumably every exploration action is
            # drawn at random from the action space.
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=EpsilonGreedy(expl_env.action_space,
                                                   prob_random_action=1.0),
                policy=policy,
            )

        qf1 = critic()
        target_qf1 = copy.deepcopy(qf1)
        qf2 = critic()
        target_qf2 = copy.deepcopy(qf2)

        policy_n.append(policy)
        eval_policy_n.append(eval_policy)
        expl_policy_n.append(expl_policy)
        qf1_n.append(qf1)
        target_qf1_n.append(target_qf1)
        qf2_n.append(qf2)
        target_qf2_n.append(target_qf2)

    from rlkit.samplers.data_collector.ma_path_collector import MAMdpPathCollector
    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, expl_policy_n)

    from rlkit.data_management.ma_env_replay_buffer import MAEnvReplayBuffer
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=num_agent)

    from rlkit.torch.masac.masac import MASACTrainer
    trainer = MASACTrainer(env=expl_env,
                           qf1_n=qf1_n,
                           target_qf1_n=target_qf1_n,
                           qf2_n=qf2_n,
                           target_qf2_n=target_qf2_n,
                           policy_n=policy_n,
                           **variant['trainer_kwargs'])

    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Esempio n. 4
0
def experiment(variant):
    """Build and train R2G with GNN critics and a shared GNN conditional actor.

    Two critic branches are built, each a GNNNet context encoder followed by
    an MLP Q-head, with target copies of both parts. One conditional actor
    ("cactor", GNN encoder + MLP + Gaussian heads) is shared across agents.
    Each agent additionally gets its own decentralized TanhGaussianPolicy.
    Everything is handed to ``R2GGNNTrainer`` and trained with
    ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Keys read here: 'num_agent',
            'graph_kwargs' ('hidden_dim', 'conv_type'), 'qf_kwargs'
            ('hidden_dim'), 'cactor_kwargs' ('hidden_dim'), 'policy_kwargs'
            ('hidden_dim'), 'random_exploration', 'replay_buffer_size',
            'trainer_kwargs', 'algorithm_kwargs'.
    """
    num_agent = variant['num_agent']
    from differential_game import DifferentialGame
    expl_env = DifferentialGame(game_name=args.exp_name)
    eval_env = DifferentialGame(game_name=args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    from rlkit.torch.networks.graph_builders import FullGraphBuilder
    from rlkit.torch.networks.gnn_networks import GNNNet
    from rlkit.torch.networks.layers import SplitLayer
    from rlkit.torch.policies.tanh_gaussian_policy import TanhGaussianPolicy
    from rlkit.torch.policies.make_deterministic import MakeDeterministic
    from rlkit.exploration_strategies.base import PolicyWrappedWithExplorationStrategy

    graph_hidden = variant['graph_kwargs']['hidden_dim']
    qf_hidden = variant['qf_kwargs']['hidden_dim']
    cactor_hidden = variant['cactor_kwargs']['hidden_dim']
    policy_hidden = variant['policy_kwargs']['hidden_dim']

    def make_context_net():
        # Graph over all agents without self loops; node input size is
        # obs_dim + action_dim.
        graph_builder = FullGraphBuilder(input_node_dim=obs_dim + action_dim,
                                         num_node=num_agent,
                                         contain_self_loop=False)
        return GNNNet(
            graph_builder,
            node_dim=graph_hidden,
            conv_type=variant['graph_kwargs']['conv_type'],
            num_conv_layers=2,
            hidden_activation='relu',
            output_activation='relu',
        )

    def make_q_head():
        # Maps a GNN node embedding concatenated with an action to one value.
        return nn.Sequential(
            nn.Linear(graph_hidden + action_dim, qf_hidden), nn.ReLU(),
            nn.Linear(qf_hidden, 1))

    def make_gaussian_heads(in_dim):
        # Two parallel action-sized linear heads consumed by
        # TanhGaussianPolicy (presumably mean and log-std -- confirm).
        # `in_dim` must equal the width of the layer feeding the heads.
        return SplitLayer(layers=[
            nn.Linear(in_dim, action_dim),
            nn.Linear(in_dim, action_dim)
        ])

    # Critic branch 1 (encoder + Q-head) and its target copies.
    cg1 = make_context_net()
    target_cg1 = copy.deepcopy(cg1)
    qf1 = make_q_head()
    target_qf1 = copy.deepcopy(qf1)

    # Critic branch 2, built identically.
    cg2 = make_context_net()
    target_cg2 = copy.deepcopy(cg2)
    qf2 = make_q_head()
    target_qf2 = copy.deepcopy(qf2)

    # Conditional actor shared by all agents. Its hidden layer outputs
    # `cactor_hidden` features, so the Gaussian heads are sized from the
    # cactor config rather than the policy config. Using policy_kwargs here
    # gives a shape mismatch whenever the two hidden sizes differ.
    cgca = make_context_net()
    cactor = nn.Sequential(
        cgca,
        nn.Linear(graph_hidden, cactor_hidden), nn.ReLU(),
        make_gaussian_heads(cactor_hidden))
    cactor = TanhGaussianPolicy(module=cactor)

    policy_n, expl_policy_n, eval_policy_n = [], [], []
    for _ in range(num_agent):
        # Decentralized policy: each agent sees only its own observation.
        policy = nn.Sequential(
            nn.Linear(obs_dim, policy_hidden),
            nn.ReLU(),
            nn.Linear(policy_hidden, policy_hidden), nn.ReLU(),
            make_gaussian_heads(policy_hidden))
        policy = TanhGaussianPolicy(module=policy)
        eval_policy = MakeDeterministic(policy)
        if variant['random_exploration']:
            from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
            # prob_random_action=1.0: presumably every exploration action is
            # drawn at random from the action space.
            expl_policy = PolicyWrappedWithExplorationStrategy(
                exploration_strategy=EpsilonGreedy(expl_env.action_space,
                                                   prob_random_action=1.0),
                policy=policy,
            )
        else:
            expl_policy = policy

        policy_n.append(policy)
        expl_policy_n.append(expl_policy)
        eval_policy_n.append(eval_policy)

    from rlkit.samplers.data_collector.ma_path_collector import MAMdpPathCollector
    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, expl_policy_n)

    from rlkit.data_management.ma_env_replay_buffer import MAEnvReplayBuffer
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=num_agent)

    from rlkit.torch.r2g.r2g_gnn2 import R2GGNNTrainer
    trainer = R2GGNNTrainer(env=expl_env,
                            cg1=cg1,
                            target_cg1=target_cg1,
                            qf1=qf1,
                            target_qf1=target_qf1,
                            cg2=cg2,
                            target_cg2=target_cg2,
                            qf2=qf2,
                            target_qf2=target_qf2,
                            cactor=cactor,
                            policy_n=policy_n,
                            **variant['trainer_kwargs'])

    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Esempio n. 5
0
def experiment(variant):
    """Build and train R2G with graph-context encoders and per-agent heads.

    Three GraphContextNet encoders are built over a graph of all agents. Two
    serve the twin critics and each has a target copy. The third, ``cgca``,
    serves the conditional actors and has no target copy. Every agent then
    gets its own pair of Q-heads with targets, a conditional actor ("cactor")
    and a TanhGaussianPolicy. Everything is handed to the ``r2g_gnn4``
    ``R2GGNNTrainer`` and trained with ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Keys read here: 'num_agent',
            'graph_kwargs' ('use_attention', 'num_layer', 'hidden_dim'),
            'qf_kwargs' / 'cactor_kwargs' / 'policy_kwargs' (each
            'hidden_dim' and 'num_layer'), 'algorithm_kwargs' (including
            'batch_size'), 'replay_buffer_size', 'trainer_kwargs'.
    """
    num_agent = variant['num_agent']
    from differential_game import DifferentialGame
    expl_env = DifferentialGame(game_name=args.exp_name)
    eval_env = DifferentialGame(game_name=args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    from rlkit.torch.networks.graph_builders import FullGraphBuilder
    from rlkit.torch.networks.graph_context_network import GraphContextNet
    from rlkit.torch.networks.networks import FlattenMlp
    from rlkit.torch.networks.layers import SplitLayer
    from rlkit.torch.policies.tanh_gaussian_policy import TanhGaussianPolicy
    from rlkit.torch.policies.make_deterministic import MakeDeterministic

    graph_kwargs = variant['graph_kwargs']
    qf_kwargs = variant['qf_kwargs']
    cactor_kwargs = variant['cactor_kwargs']
    policy_kwargs = variant['policy_kwargs']
    graph_hidden = graph_kwargs['hidden_dim']

    def make_context_net():
        # Graph over all agents without self loops, sized for the training
        # batch; node input size is obs_dim + action_dim.
        graph_builder = FullGraphBuilder(
            input_node_dim=obs_dim + action_dim,
            num_node=num_agent,
            batch_size=variant['algorithm_kwargs']['batch_size'],
            contain_self_loop=False)
        return GraphContextNet(
            graph_builder,
            obs_dim,
            action_dim,
            use_attention=graph_kwargs['use_attention'],
            num_layer=graph_kwargs['num_layer'],
            node_dim=graph_hidden,
            output_activation='relu',
        )

    def make_mlp(input_size, output_size, kwargs):
        # hidden_sizes gets num_layer - 1 entries of width hidden_dim;
        # FlattenMlp presumably adds the output layer on top.
        return FlattenMlp(
            input_size=input_size,
            output_size=output_size,
            hidden_sizes=[kwargs['hidden_dim']] * (kwargs['num_layer'] - 1),
        )

    def make_gaussian_heads(in_dim):
        # Two parallel action-sized linear heads consumed by
        # TanhGaussianPolicy (presumably mean and log-std -- confirm).
        # `in_dim` must equal the width of the layer feeding the heads.
        return SplitLayer(layers=[
            nn.Linear(in_dim, action_dim),
            nn.Linear(in_dim, action_dim)
        ])

    cg1 = make_context_net()
    target_cg1 = copy.deepcopy(cg1)
    cg2 = make_context_net()
    target_cg2 = copy.deepcopy(cg2)
    cgca = make_context_net()

    q_input = graph_hidden + action_dim
    cactor_hidden = cactor_kwargs['hidden_dim']
    policy_hidden = policy_kwargs['hidden_dim']

    policy_n, expl_policy_n, eval_policy_n = [], [], []
    qf1_n, target_qf1_n, qf2_n, target_qf2_n = [], [], [], []
    cactor_n = []
    for _ in range(num_agent):
        # Q-heads take a graph context embedding plus an action.
        qf1 = make_mlp(q_input, 1, qf_kwargs)
        target_qf1 = copy.deepcopy(qf1)
        qf2 = make_mlp(q_input, 1, qf_kwargs)
        target_qf2 = copy.deepcopy(qf2)

        # The Gaussian heads read the cactor trunk's output, whose width is
        # cactor_kwargs['hidden_dim']. Sizing them from policy_kwargs gives
        # a shape mismatch whenever the two hidden sizes differ.
        cactor = nn.Sequential(
            make_mlp(graph_hidden, cactor_hidden, cactor_kwargs),
            nn.ReLU(),
            make_gaussian_heads(cactor_hidden))
        cactor = TanhGaussianPolicy(module=cactor)

        # NOTE(review): unlike the cactor, there is no nn.ReLU() between this
        # trunk and the heads. If FlattenMlp's output activation is linear,
        # its last layer and the heads compose into a single linear map.
        # Confirm whether that is intended.
        policy = nn.Sequential(
            make_mlp(obs_dim, policy_hidden, policy_kwargs),
            make_gaussian_heads(policy_hidden))
        policy = TanhGaussianPolicy(module=policy)
        eval_policy = MakeDeterministic(policy)
        # Exploration uses the policy itself, with no exploration wrapper.
        expl_policy = policy

        policy_n.append(policy)
        expl_policy_n.append(expl_policy)
        eval_policy_n.append(eval_policy)
        qf1_n.append(qf1)
        target_qf1_n.append(target_qf1)
        qf2_n.append(qf2)
        target_qf2_n.append(target_qf2)
        cactor_n.append(cactor)

    from rlkit.samplers.data_collector.ma_path_collector import MAMdpPathCollector
    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, expl_policy_n)

    from rlkit.data_management.ma_env_replay_buffer import MAEnvReplayBuffer
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=num_agent)

    from rlkit.torch.r2g.r2g_gnn4 import R2GGNNTrainer
    trainer = R2GGNNTrainer(env=expl_env,
                            cg1=cg1,
                            target_cg1=target_cg1,
                            qf1_n=qf1_n,
                            target_qf1_n=target_qf1_n,
                            cg2=cg2,
                            target_cg2=target_cg2,
                            qf2_n=qf2_n,
                            target_qf2_n=target_qf2_n,
                            cgca=cgca,
                            cactor_n=cactor_n,
                            policy_n=policy_n,
                            **variant['trainer_kwargs'])

    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()
Esempio n. 6
0
import csv
import os.path
import matplotlib
matplotlib.rcParams.update({'font.size': 10})
from matplotlib import pyplot as plt
import numpy as np
import argparse
# Command-line parameters.
parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', type=str, default='zero_sum')
args = parser.parse_args()

plot_file = './Data/' + args.exp_name + '_reward.png'

from differential_game import DifferentialGame
env = DifferentialGame(game_name=args.exp_name)

# Samples per action axis. The grids and both reward surfaces derive their
# size from this single value, so the shapes cannot drift apart.
GRID_SIZE = 100
xs = np.linspace(-1, 1, GRID_SIZE)
ys = np.linspace(-1, 1, GRID_SIZE)
# Rows index ys and columns index xs: this is the (len(y), len(x)) layout
# contourf expects when given 1-D x / y coordinates.
z1s = np.zeros((len(ys), len(xs)))
z2s = np.zeros((len(ys), len(xs)))

# Evaluate one step from a fresh reset at every joint action [x, y] and
# record the first and second entries of the returned reward list.
for i, x in enumerate(xs):
    for j, y in enumerate(ys):
        env.reset()
        o_n, r_n, d_n, info = env.step([x, y])
        z1s[j, i] = r_n[0]
        z2s[j, i] = r_n[1]
plt.figure()
plt.subplot(1, 2, 1)
plt.contourf(xs, ys, z1s)
Esempio n. 7
0
def experiment(variant):
    """Build and train the MADDPG baseline on a differential game.

    Each agent gets a centralized critic with a target copy, a deterministic
    TanhMlpPolicy with a target copy, and an exploration wrapper. The wrapper
    uses epsilon-greedy with probability 1.0 when 'random_exploration' is set
    and Ornstein-Uhlenbeck noise otherwise. When
    ``trainer_kwargs['double_q']`` is truthy, a second critic and its target
    are built as well; otherwise the second-critic lists stay empty.
    Everything is passed to ``MADDPGTrainer`` and trained with
    ``TorchBatchRLAlgorithm``.

    Args:
        variant: experiment config dict. Keys read here: 'num_agent',
            'qf_kwargs' ('hidden_dim'), 'policy_kwargs' ('hidden_dim'),
            'random_exploration', 'trainer_kwargs' (including 'double_q'),
            'replay_buffer_size', 'algorithm_kwargs'.
    """
    num_agent = variant['num_agent']
    from differential_game import DifferentialGame
    expl_env = DifferentialGame(game_name=args.exp_name)
    eval_env = DifferentialGame(game_name=args.exp_name)
    obs_dim = eval_env.observation_space.low.size
    action_dim = eval_env.action_space.low.size

    from rlkit.torch.networks import FlattenMlp
    from rlkit.torch.policies.deterministic_policies import TanhMlpPolicy
    from rlkit.exploration_strategies.base import PolicyWrappedWithExplorationStrategy
    from rlkit.exploration_strategies.epsilon_greedy import EpsilonGreedy
    from rlkit.exploration_strategies.ou_strategy import OUStrategy

    def build_critic():
        # Centralized Q-network over every observation and every action.
        return FlattenMlp(
            input_size=(obs_dim * num_agent + action_dim * num_agent),
            output_size=1,
            hidden_sizes=[variant['qf_kwargs']['hidden_dim']] * 2,
        )

    qf_n, target_qf_n, qf2_n, target_qf2_n = [], [], [], []
    policy_n, target_policy_n = [], []
    eval_policy_n, expl_policy_n = [], []
    for _ in range(num_agent):
        critic = build_critic()
        qf_n.append(critic)
        target_qf_n.append(copy.deepcopy(critic))

        actor = TanhMlpPolicy(
            input_size=obs_dim,
            output_size=action_dim,
            hidden_sizes=[variant['policy_kwargs']['hidden_dim']] * 2,
        )
        policy_n.append(actor)
        target_policy_n.append(copy.deepcopy(actor))
        # Evaluation acts with the bare policy, without exploration noise.
        eval_policy_n.append(actor)

        strategy = (
            EpsilonGreedy(expl_env.action_space, prob_random_action=1.0)
            if variant['random_exploration'] else
            OUStrategy(action_space=expl_env.action_space))
        expl_policy_n.append(
            PolicyWrappedWithExplorationStrategy(
                exploration_strategy=strategy,
                policy=actor,
            ))

        if not variant['trainer_kwargs']['double_q']:
            continue
        second_critic = build_critic()
        qf2_n.append(second_critic)
        target_qf2_n.append(copy.deepcopy(second_critic))

    from rlkit.samplers.data_collector.ma_path_collector import MAMdpPathCollector
    eval_path_collector = MAMdpPathCollector(eval_env, eval_policy_n)
    expl_path_collector = MAMdpPathCollector(expl_env, expl_policy_n)

    from rlkit.data_management.ma_env_replay_buffer import MAEnvReplayBuffer
    replay_buffer = MAEnvReplayBuffer(variant['replay_buffer_size'],
                                      expl_env,
                                      num_agent=num_agent)

    from rlkit.torch.maddpg.maddpg import MADDPGTrainer
    trainer = MADDPGTrainer(qf_n=qf_n,
                            target_qf_n=target_qf_n,
                            policy_n=policy_n,
                            target_policy_n=target_policy_n,
                            qf2_n=qf2_n,
                            target_qf2_n=target_qf2_n,
                            **variant['trainer_kwargs'])

    from rlkit.torch.torch_rl_algorithm import TorchBatchRLAlgorithm
    algorithm = TorchBatchRLAlgorithm(
        trainer=trainer,
        exploration_env=expl_env,
        evaluation_env=eval_env,
        exploration_data_collector=expl_path_collector,
        evaluation_data_collector=eval_path_collector,
        replay_buffer=replay_buffer,
        log_path_function=get_generic_ma_path_information,
        **variant['algorithm_kwargs'])
    algorithm.to(ptu.device)
    algorithm.train()