Code example #1
0
File: simple_test.py  Project: Cheryl0605/RL4Net
__author__ = "Jiawei Wu"
__copyright__ = "Copyright (c) 2020, Beijing University of Posts and Telecommunications."
__version__ = "0.1.0"
__email__ = "*****@*****.**"



simArgs = {"--maxStep": 10,
            "--routingMethod": 'rl',
            "--adjacencyMatrix": '''[0,1,1,1,-1,0,-1,1,-1,-1,0,1,-1,-1,-1,0]''',
            "--trafficMatrix": '''[{/src/:0,/dst/:3,/rate/:4.736}]'''
            }

print('  start a sim')

env = ns3env.Ns3Env(port=5555, stepTime=0.5, startSim=True, simSeed=0, simArgs=simArgs)
ob_space = env.observation_space
ac_space = env.action_space
print("Observation space: ", ob_space,  ob_space.dtype)
print("Action space: ", ac_space, ac_space.dtype)
obs = env.reset()
print("---obs: ", obs)

stepIdx = 0
try:
    while True:
        mid_weight = stepIdx * 0.1
        lr_weight = (1 - mid_weight) / 2
        action = [lr_weight, lr_weight, mid_weight, 1, 1]
        print("---action: ", action)
Code example #2
0
 def setup_class(cls):
     """Spin up one shared ns-3 environment before any test in this class runs."""
     cls.env = ns3env.Ns3Env(port=5555, stepTime=0.5, startSim=True,
                             simSeed=0, simArgs={"--maxStep": 2},
                             simScriptName='udp-fm')
     # Cache the observation shape for the tests' assertions.
     cls.ob_shape = cls.env.observation_space.shape
Code example #3
0
File: demo.py  Project: xianliangjiang/RL4Net
                    default=1,
                    help='Number of iterations, Default: 1')
# Read the CLI flags declared on `parser` above.
args = parser.parse_args()
startSim = bool(args.start)
iterationNum = int(args.iterations)

# Fixed connection / simulation settings handed through to the ns-3 script.
port = 5555
simTime = 5  # seconds
stepTime = 0.5  # seconds
seed = 0
debug = False
simArgs = {"--simTime": simTime, "--stepTime": stepTime, "--testArg": 123}

# Fully-specified construction; a bare ns3env.Ns3Env() would use defaults.
env = ns3env.Ns3Env(port=port, stepTime=stepTime, startSim=startSim,
                    simSeed=seed, simArgs=simArgs, debug=debug)
env.reset()

ob_space, ac_space = env.observation_space, env.action_space
print("Observation space: ", ob_space, ob_space.dtype)
print("Action space: ", ac_space, ac_space.dtype)

# Counters for the interaction loop that follows.
stepIdx, currIt = 0, 0
try:
Code example #4
0
File: AC-trainer.py  Project: bupt-ipcr/RL4Net-TE
def rl_loop(agent, need_load=True):
    """Main reinforcement-learning training loop.

    Each episode draws a random traffic matrix, rebuilds an ns-3 environment
    for it, and lets the agent interact for up to MAX_STEP steps, learning
    from its replay buffer after every step.

    Args:
        agent: RL agent providing load(), save(e), get_summary_writer(),
            get_action_noise(state, rate), add_step(...) and learn_batch().
        need_load: when True, resume from the episode index returned by
            agent.load(); otherwise start from episode 0.
    """
    START_EPISODE = agent.load() if need_load else 0
    # Pre-bind every name the except/finally blocks touch so that an
    # exception raised before assignment cannot turn cleanup into a
    # NameError (the original crashed in `finally` on an early failure,
    # and in the KeyboardInterrupt handler if interrupted before the loop).
    summary_writer = None
    env = None
    e = START_EPISODE
    try:
        summary_writer = agent.get_summary_writer()
        for e in range(START_EPISODE, MAX_EPISODE):
            cum_reward = 0
            print("Start episode: ", e, flush=True)
            # Randomly pick a traffic matrix as this episode's start state.
            cur_index = random.randint(0, len(traffic_matrix_state_list) - 1)
            cur_state = traffic_matrix_state_list[cur_index]
            # Rebuild the simulation environment for this particular state.
            simArgs.update(
                {"--trafficMatrix": traffic_matrix_str_list[cur_index]})
            # NOTE(review): 'simScripteName' mirrors the (typo'd) module-level
            # global; renaming it requires changing its definition too.
            env = ns3env.Ns3Env(port=5555,
                                stepTime=stepTime,
                                startSim=startSim,
                                simSeed=seed,
                                simArgs=simArgs,
                                debug=debug,
                                simScriptName=simScripteName)
            cur_state = env.reset()
            for s in range(MAX_STEP):
                print("Step: ", s, flush=True)

                # Exploration noise decays linearly over the first 100
                # episodes, floored at 1%.
                noise_decay_rate = max((100 - e) / 100, 0.01)
                action = agent.get_action_noise(cur_state,
                                                rate=noise_decay_rate)[0]

                next_state, reward, done, info = env.step(action)

                cum_reward += reward
                info = {
                    "cur_state": list(cur_state),
                    "action": list(action),
                    "next_state": list(next_state),
                    "reward": reward,
                    "done": done
                }
                print(json.dumps(info))

                # Store the transition in the experience-replay buffer.
                agent.add_step(np.array(cur_state), action, reward, done,
                               np.array(next_state))

                cur_state = next_state

                if done:
                    break
                # Train on a batch sampled from the replay buffer.
                agent.learn_batch()
            summary_writer.add_scalar('cum_reward', cum_reward, e)
            agent.save(e)  # persist network parameters
            if env is not None:
                env.close()
                env = None  # prevent a second close() in `finally`
    except KeyboardInterrupt:
        print('正在保存网络参数,请不要退出\r', end='')
        agent.save(e)
        print("Ctrl-C -> Exit                   ")
    finally:
        if summary_writer is not None:
            summary_writer.close()
        if env is not None:
            env.close()
        print("Done!")