Code example #1
0
File: train.py — Project: schroederdewitt/maddpg
def my_main(_run, _config, _log):
    """Sacred entry point: parse CLI args, set up logging, and train.

    :param _run: sacred Run object (forwarded to the sacred logger).
    :param _config: sacred config dict; passed through to train().
    :param _log: python logger supplied by sacred.
    """
    global mongo_client

    import datetime

    arglist = parse_args()

    logger = Logger(_log)

    # Unique run token for the tensorboard directory name,
    # e.g. "myexp__2021-04-28_09-40-29".
    # (The earlier duplicate computed from _config["name"] was dead code.)
    unique_token = "{}__{}".format(
        arglist.exp_name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

    # Tensorboard is hard-disabled here; flip this flag to log under
    # <repo>/results/tb_logs/<unique_token>.
    use_tensorboard = False
    if use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)
    logger.setup_sacred(_run)

    train(arglist, logger, _config)

    # Force exit: worker threads may still be alive, so bypass normal teardown.
    os._exit(0)
Code example #2
0
def run(_run, _config, _log):
    """Top-level experiment driver, called from the sacred main function.

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    """
    # Sanity-check / adjust some config defaults (e.g. cuda, batch sizes).
    _config = args_sanity_check(_config, _log)
    # Expose the config with attribute-style (Namespace) access.
    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # Set up logging.
    logger = Logger(_log)

    _log.info("打印实验参数: ")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # Tensorboard run token, e.g. 'qmix_env=8_adam_td_lambda__2021-04-28_09-40-29'.
    unique_token = "{}__{}".format(
        args.name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    # When tensorboard is enabled, write event files under
    # <repo>/results/tb_logs/<unique_token>.
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    # Sacred logging is on by default.
    logger.setup_sacred(_run)

    # Run and train.
    run_sequential(args=args, logger=logger)

    # Clean up after finishing.
    print("退出主程序")

    print("停止所有线程")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("退出 script")

    # Make sure the framework really exits even if non-daemon threads remain.
    # NOTE(review): os.EX_OK is Unix-only.
    os._exit(os.EX_OK)
Code example #3
0
def run(_run, _config, _log):
    """Top-level experiment driver called from the sacred main function.

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    """
    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    # Device selection: the SET_DEVICE env var picks a specific GPU index,
    # SET_DEVICE=-1 (or use_cuda=False) forces CPU, unset means plain "cuda".
    # (The original unconditional "cuda"/"cpu" assignment here was dead code:
    # it was always overwritten by this branch.)
    set_device = os.getenv('SET_DEVICE')
    if args.use_cuda and set_device != '-1':
        args.device = "cuda" if set_device is None else f"cuda:{set_device}"
    else:
        args.device = "cpu"

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # configure tensorboard logger
    unique_token = "{}__{}".format(
        args.name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits (os.EX_OK is Unix-only).
    os._exit(os.EX_OK)
Code example #4
0
def run(_run, _config, _log):
    """Top-level experiment driver called from the sacred main function.

    Also dumps the resolved config to config.yaml inside the tensorboard
    directory when tensorboard logging is enabled.

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    """

    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"
    if args.use_cuda:
        # Pin this process to the configured GPU index.
        th.cuda.set_device(args.device_num)

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # configure tensorboard logger
    unique_token = "{}__{}".format(
        args.name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        args.tb_logs = tb_exp_direc
        logger.setup_tb(tb_exp_direc)
        # Dump the full config next to the tensorboard logs for reproducibility.
        # NOTE(review): assumes setup_tb() created tb_exp_direc — confirm.
        with open(os.path.join(tb_exp_direc, "config.yaml"), "w") as f:
            yaml.dump(_config, f, default_flow_style=False)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits (os.EX_OK is Unix-only).
    os._exit(os.EX_OK)
Code example #5
0
def run(_run, _config, _log):
    """Top-level experiment driver called from the sacred main function.

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    """

    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # The map name lives under "map_name" for SMAC-style envs and under
    # "key" for others; fall back accordingly (narrowed from a bare except).
    try:
        map_name = _config["env_args"]["map_name"]
    except KeyError:
        map_name = _config["env_args"]["key"]
    # NOTE(review): str(datetime.now()) contains spaces and colons, which
    # make awkward directory names on some filesystems — confirm intended.
    unique_token = f"{_config['name']}_seed{_config['seed']}_{map_name}_{datetime.datetime.now()}"

    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")
Code example #6
0
def my_main(_run, _config, _log):
    """Sacred entry point: set up logging, then run the Runner on the config.

    :param _run: sacred Run object (forwarded to the sacred logger).
    :param _config: sacred config dict; must contain a "label" key.
    :param _log: python logger supplied by sacred.
    """
    global mongo_client

    import datetime

    logger = Logger(_log)

    # Unique run token for the tensorboard directory name,
    # e.g. "mylabel__2021-04-28_09-40-29".
    unique_token = "{}__{}".format(
        _config["label"],
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    # Tensorboard is hard-disabled here; flip this flag to log under
    # <repo>/results/tb_logs/<unique_token>.
    use_tensorboard = False
    if use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)
    logger.setup_sacred(_run)

    _log.info("Experiment Parameters:")
    import pprint
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # Start the training process: load the config, reset, then run.
    runner = Runner(logger)
    runner.load(_config)
    runner.reset()
    runner.run(_config)

    # Force exit: worker threads may still be alive, so bypass normal teardown.
    os._exit(0)
Code example #7
0
def run(_run, _config, _log, pymongo_client):
    """Top-level experiment driver; closes the mongodb client on exit.

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    :param pymongo_client: optional MongoClient; closed at the end if not None.
    """

    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    # configure tensorboard logger
    unique_token = "{}__{}".format(
        args.name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train; cross-play evaluation uses a dedicated sequential runner.
    if args.cross_play and args.evaluate:
        run_sequential_cross(args=args, logger=logger)
    else:
        run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    if pymongo_client is not None:
        print("Attempting to close mongodb client")
        pymongo_client.close()
        # Fix: this message previously printed even when no client existed.
        print("Mongodb client closed")
Code example #8
0
    def setup_components(logger,
                         agent_state_dict):
        """Build per-task runners, replay buffers, macs, and learners.

        NOTE(review): relies on closure variables from the enclosing scope
        (task_configs, _log, _run, args, loggers, runners, buffers, macs,
        learners, and the r_/mac_/le_ registries) — confirm against the
        full file.

        :param logger: incoming logger; shadowed by a per-task Logger below,
            so the argument itself is never used.
        :param agent_state_dict: optional pre-trained agent weights; currently
            unused (the weight-copy logic is commented out).
        """
        # Collect task names up front so every task's mac can share the agent
        # of the first task (see macs[task_names[0]] below).
        task_names = []
        for task_name, _ in task_configs.items():
            task_names.append(task_name)

        # set up tasks based on the configs
        for task_name, task_config in task_configs.items():

            task_args = task_config

            from copy import deepcopy
            # NOTE(review): this rebinds (shadows) the `logger` parameter
            # with a fresh per-task Logger.
            logger = Logger(_log)
            # sacred is on by default
            logger.setup_sacred(_run)
            # logger = deepcopy(meta_logger)
            logger.prefix = task_name
            loggers[task_name] = logger

            # Init runner so we can get env info
            runner = r_REGISTRY[task_args.runner](args=task_args,
                                                  logger=logger)
            runners[task_name] = runner

            # Set up schemes and groups here.
            # Decoders/encoders arrive dill-serialized from the env; a value
            # of None means the env does not provide them.
            env_info = runner.get_env_info()
            task_args.n_agents = env_info["n_agents"]
            task_args.n_actions = env_info["n_actions"]
            task_args.obs_decoder = dill.loads(env_info["obs_decoder"]) if env_info["obs_decoder"] is not None else None
            task_args.avail_actions_encoder = dill.loads(env_info["avail_actions_encoder_grid"]) if env_info[
                                                                                                   "avail_actions_encoder_grid"] is not None else None
            task_args.db_url = args.db_url
            task_args.db_name = args.db_name
            task_args.state_shape = env_info["state_shape"]
            task_args.state_decoder = dill.loads(env_info["state_decoder"]) if env_info["state_decoder"] is not None else None
            # NOTE(review): obs_decoder is assigned twice (also above) —
            # redundant but harmless.
            task_args.obs_decoder = dill.loads(env_info["obs_decoder"]) if env_info["obs_decoder"] is not None else None

            # Default/Base scheme describing one timestep of replay data.
            scheme = {
                "state": {"vshape": env_info["state_shape"]},
                "obs": {"vshape": env_info["obs_shape"], "group": "agents",
                        "vshape_decoded": env_info.get("obs_shape_decoded", env_info["obs_shape"])},
                "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
                "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
                "reward": {"vshape": (1,)},
                "terminated": {"vshape": (1,), "dtype": th.uint8},
            }
            groups = {
                "agents": task_args.n_agents
            }
            preprocess = {
                "actions": ("actions_onehot", [OneHot(out_dim=task_args.n_actions)])
            }

            # episode_limit + 1 so the buffer can hold the terminal transition.
            buffer = ReplayBuffer(scheme, groups, task_args.buffer_size, env_info["episode_limit"] + 1,
                                  preprocess=preprocess,
                                  device="cpu" if task_args.buffer_cpu_only else args.device)
            buffers[task_name] = buffer

            # Setup multiagent controller here
            mac = mac_REGISTRY[task_args.mac](buffer.scheme, groups, task_args)

            # point model to same object:
            # every task's mac shares the first task's agent network.
            macs[task_name] = mac
            mac.agent = macs[task_names[0]].agent

            # Give runner the scheme
            runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)

            # Learner
            learner = le_REGISTRY[task_args.learner](mac, buffer.scheme, logger, task_args)
            learners[task_name] = learner

            if task_args.use_cuda:
                learner.cuda()

            #if agent_state_dict is None:
            #    agent_state_dict = mac.agent.state_dict()
            # else:
            #    # copy all weights that have same dimensions
            #    sd = mac.agent.state_dict()
            #    for k, v in agent_state_dict.items():
            #        if (k in sd) and (v.shape == sd[k].shape):
            #            setattr(mac.agent, k, v)


            # Optionally restore a checkpoint for this task.
            if task_args.checkpoint_path != "":

                timesteps = []
                timestep_to_load = 0

                # NOTE(review): returning here aborts setup for ALL remaining
                # tasks, not just this one — confirm intended.
                if not os.path.isdir(task_args.checkpoint_path):
                    logger.console_logger.info("Checkpoint directory {} doesn't exist".format(task_args.checkpoint_path))
                    return

                # Go through all files in args.checkpoint_path
                for name in os.listdir(task_args.checkpoint_path):
                    full_name = os.path.join(task_args.checkpoint_path, name)
                    # Check if they are dirs the names of which are numbers
                    if os.path.isdir(full_name) and name.isdigit():
                        timesteps.append(int(name))

                if task_args.load_step == 0:
                    # choose the max timestep
                    timestep_to_load = max(timesteps)
                else:
                    # choose the timestep closest to load_step
                    timestep_to_load = min(timesteps, key=lambda x: abs(x - task_args.load_step))

                model_path = os.path.join(task_args.checkpoint_path, str(timestep_to_load))

                logger.console_logger.info("Loading model from {}".format(model_path))
                learner.load_models(model_path)
                runner.t_env = timestep_to_load

                # Evaluation-only / replay-saving runs stop after loading.
                if task_args.evaluate or task_args.save_replay:
                    evaluate_sequential(task_args, runner)
                    return
        return
Code example #9
0
def run(_run, _config, _log):
    """Experiment driver for Sub-AVG runs.

    Builds a human-readable run descriptor from the Sub-AVG settings, prints
    the configuration summary, sets up logging, and trains via
    run_sequential().

    :param _run: sacred Run object.
    :param _config: experiment config dict from sacred.
    :param _log: python logger supplied by sacred.
    """

    # check args sanity
    _config = args_sanity_check(_config, _log)

    args = SN(**_config)
    args.device = "cuda" if args.use_cuda else "cpu"

    # setup loggers
    logger = Logger(_log)

    _log.info("Experiment Parameters:")
    experiment_params = pprint.pformat(_config, indent=4, width=1)
    _log.info("\n\n" + experiment_params + "\n")

    print("\n=================")
    # Auto-generated descriptor: env + map + algorithm name.
    auto_describe = 'env_' + args.env_args['map_name'] + '_' + args.name.split(
        '_')[0] + '_'  # env map algo
    if args.name.split('_')[0] == 'smix':
        auto_describe = auto_describe + args.mixer + '_'
    # Append Sub-AVG (average / sub-average) settings when either the agent
    # or the mixer uses them. (Inverted from the original
    # "if both zero: pass / else:" anti-pattern; logically equivalent.)
    if args.SubAVG_Agent_flag != 0 or args.SubAVG_Mixer_flag != 0:
        if args.SubAVG_Agent_flag == 1:
            auto_describe = auto_describe + 'average' + str(
                args.SubAVG_Agent_K)
            if args.SubAVG_Agent_flag_select < 0:
                auto_describe = auto_describe + 'neg' + (
                    "Mean" if args.SubAVG_Agent_name_select_replacement
                    == 'mean' else "Zero_start2") + '_'
            elif args.SubAVG_Agent_flag_select > 0:
                auto_describe = auto_describe + 'pos' + (
                    "Mean" if args.SubAVG_Agent_name_select_replacement
                    == 'mean' else "Zero_start2") + '_'
        if args.SubAVG_Mixer_flag == 1:
            auto_describe = auto_describe + 'mix' + str(args.SubAVG_Mixer_K)
            if args.SubAVG_Mixer_flag_select < 0:
                auto_describe = auto_describe + 'neg' + (
                    "Mean" if args.SubAVG_Mixer_name_select_replacement
                    == 'mean' else "Zero_start2") + '_'
            elif args.SubAVG_Mixer_flag_select > 0:
                auto_describe = auto_describe + 'pos' + (
                    "Mean" if args.SubAVG_Mixer_name_select_replacement
                    == 'mean' else "Zero_start2") + '_'
    # Double-Q suffix.
    if args.double_q:
        auto_describe = auto_describe + 'Double'
    else:
        auto_describe = auto_describe + 'vanilla'

    args.z_auto_describe = auto_describe
    print(auto_describe)

    print(args.z_describe)

    # Human-readable summary of the Sub-AVG configuration. Note precedence:
    # "..." % x if flag else "No" parses as ("..." % x) if flag else "No".
    print("using Sub-AVG Agent:", end=' ')
    print("Yes  K: %d" %
          args.SubAVG_Agent_K if args.SubAVG_Agent_flag else "No",
          end=' | ')
    print("select > or < average:", end=' ')
    print("Yes  signal: %d" % args.SubAVG_Agent_flag_select
          if args.SubAVG_Agent_flag_select else "No",
          end=' | ')
    print("replace average_select by :", end=' ')
    print("mean" if args.SubAVG_Agent_name_select_replacement ==
          'mean' else "zero")

    print("using Sub-AVG Mixer:", end=' ')
    print("Yes  K: %d" %
          args.SubAVG_Mixer_K if args.SubAVG_Mixer_flag else "No",
          end=' | ')
    print("mixer: select > or <:", end=' ')
    print("Yes  signal: %d" % args.SubAVG_Mixer_flag_select
          if args.SubAVG_Mixer_flag_select else "No",
          end=' | ')
    print("replace mixer_select by :", end=' ')
    print("mean" if args.SubAVG_Mixer_name_select_replacement ==
          'mean' else "zero")

    print("using double_q: ", end=' ')
    print('Yes' if args.double_q else "No", end=' | ')

    # -----------------------------
    print("\n==================\n")

    # configure tensorboard logger
    unique_token = "{}__{}".format(
        args.name,
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    args.unique_token = unique_token
    if args.use_tensorboard:
        tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))),
                                     "results", "tb_logs")
        tb_exp_direc = os.path.join(tb_logs_direc, unique_token)
        logger.setup_tb(tb_exp_direc)

    # sacred is on by default
    logger.setup_sacred(_run)

    # Run and train
    run_sequential(args=args, logger=logger)

    # Clean up after finishing
    print("Exiting Main")

    print("Stopping all threads")
    for t in threading.enumerate():
        if t.name != "MainThread":
            print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon))
            t.join(timeout=1)
            print("Thread joined")

    print("Exiting script")

    # Making sure framework really exits (os.EX_OK is Unix-only).
    os._exit(os.EX_OK)