Example #1
def train(model, memory, model_config, env_config, device, weight_file):
    gamma = model_config.getfloat('model', 'gamma')
    batch_size = model_config.getint('train', 'batch_size')
    learning_rate = model_config.getfloat('train', 'learning_rate')
    step_size = model_config.getint('train', 'step_size')
    train_episodes = model_config.getint('train', 'train_episodes')
    sample_episodes = model_config.getint('train', 'sample_episodes')
    test_interval = model_config.getint('train', 'test_interval')
    test_episodes = model_config.getint('train', 'test_episodes')
    epsilon_start = model_config.getfloat('train', 'epsilon_start')
    epsilon_end = model_config.getfloat('train', 'epsilon_end')
    epsilon_decay = model_config.getfloat('train', 'epsilon_decay')
    num_epochs = model_config.getint('train', 'num_epochs')
    kinematic = env_config.getboolean('agent', 'kinematic')
    checkpoint_interval = model_config.getint('train', 'checkpoint_interval')

    criterion = nn.MSELoss().to(device)
    data_loader = DataLoader(memory, batch_size, shuffle=True)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                             step_size=step_size,
                                             gamma=0.1)
    train_env = ENV(config=env_config, phase='train')
    test_env = ENV(config=env_config, phase='test')
    duplicate_model = copy.deepcopy(model)

    episode = 0
    while episode < train_episodes:
        # epsilon-greedy
        if episode < epsilon_decay:
            epsilon = epsilon_start + (epsilon_end -
                                       epsilon_start) / epsilon_decay * episode
        else:
            epsilon = epsilon_end

        # test
        if episode % test_interval == 0:
            run_k_episodes(test_episodes, episode, model, 'test', test_env,
                           gamma, epsilon, kinematic, None, None, device)
            # update duplicate model
            duplicate_model = copy.deepcopy(model)

        # sample k episodes into memory and optimize over the generated memory
        run_k_episodes(sample_episodes, episode, model, 'train', train_env,
                       gamma, epsilon, kinematic, duplicate_model, memory,
                       device)
        optimize_batch(model, data_loader, len(memory), optimizer, None,
                       criterion, num_epochs, device)
        episode += 1

        if episode != 0 and episode % checkpoint_interval == 0:
            torch.save(model.state_dict(), weight_file)

    return model
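For reference, the epsilon schedule inside the loop above anneals linearly from epsilon_start to epsilon_end over the first epsilon_decay episodes, then stays flat. A standalone restatement of that schedule (a sketch; the default values here are illustrative assumptions, not the config used above):

def epsilon_at(episode, epsilon_start=0.5, epsilon_end=0.1, epsilon_decay=4000):
    """Linear epsilon annealing, mirroring the schedule in train() above."""
    if episode < epsilon_decay:
        return epsilon_start + (epsilon_end - epsilon_start) / epsilon_decay * episode
    return epsilon_end

assert epsilon_at(0) == 0.5
assert epsilon_at(4000) == 0.1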
Example #2
    def test_fileset_with_options(self):
        filesystemName = ENV.get_random_filesystem()
        filesetName = "fileset_" + get_random_string()
        maxNumInodes = 10048
        preA = 8256
        permissionChangeMode = random.choice(PERMISSION_CHANGE_MODE)
        comment = 'comment_' + get_random_string()
        c = RestClient()
        # create a fileset
        self.create_fileset(c, filesystemName, filesetName, False, 'new',
                            preA, maxNumInodes,
                            permissionChangeMode, comment)
        # list just this fileset ( no cache...)
        r = self.list_filesets(c, filesystemName, filesetName)
        res = json.loads(r.text)
        names = [str(s['config']['filesetName']) for s in res['filesets']]
        self.assertTrue(filesetName in names)
        fs = get_obj_from_list(res['filesets'], 'filesetName', filesetName,
                               'config')

#        self.assertTrue(int(fs['config']['maxNumInodes']) >= maxNumInodes)
#        self.assertTrue(int(fs['config']['numInodesToPreallocate']) >= preA)
        self.assertEqual(str(fs['config']['comment']), comment)
        self.assertEqual(str(fs['config']['permissionChangeMode']),
                         permissionChangeMode)
        # list all filesets ( from cache ? )
        r = self.list_filesets(c, filesystemName)
        res = json.loads(r.text)
        names = [str(s['config']['filesetName']) for s in res['filesets']]
        self.assertTrue(filesetName in names)
        # change some params
        NewfilesetName = "Newfileset_" + get_random_string()
        maxNumInodes = 20048
        preA = 16024
        link = 'chmod' in permissionChangeMode
        permissionChangeMode = random.choice(PERMISSION_CHANGE_MODE)
        comment = 'comment_' + get_random_string()
        iamMode = random.choice(IAM_MODE)
#        iamMode = 'advisory' #, 'noncompliant', 'compliant'
        self.change_fileset(c, filesystemName, filesetName, link,
                            preA, maxNumInodes, permissionChangeMode,
                            comment, NewfilesetName, iamMode)
        filesetName = NewfilesetName
        # list all filesets ( from cache !)
        r = self.list_filesets(c, filesystemName)
        res = json.loads(r.text)
        names = [str(s['config']['filesetName']) for s in res['filesets']]

        self.assertTrue(filesetName in names)
        fs = get_obj_from_list(res['filesets'], 'filesetName', filesetName,
                               'config')

#        self.assertTrue(int(fs['config']['maxNumInodes']) >= maxNumInodes)
#        self.assertTrue(int(fs['config']['numInodesToPreallocate']) >= preA)
        self.assertEqual(str(fs['config']['comment']), comment)
        self.assertEqual(str(fs['config']['permissionChangeMode']),
                         permissionChangeMode)
        self.assertEqual(str(fs['config']['iamMode']), iamMode)

        self.delete_fileset(c, filesystemName, filesetName)
Example #3
    def test_fileset_snap(self):
        filesystemName = ENV.get_random_filesystem()
        snapshotName = "snap_fileset_" + get_random_string()
        filesetName = ENV.get_random_fileset(filesystemName)
        c = RestClient()
        self.create_snapshot(c, filesystemName, snapshotName, filesetName)

        r = self.list_snapshots(c, filesystemName, filesetFilter=filesetName)
        res = json.loads(r.text)
        snapNames = [str(s['config']['snapshotName']) for s in res['snapshots']]
        self.assertTrue(snapshotName in snapNames)

        r = self.list_snapshots(c, filesystemName, snapshotName, filesetName)
        res = json.loads(r.text)
        snapNames = [str(s['config']['snapshotName']) for s in res['snapshots']]
        self.assertTrue(snapshotName in snapNames)

        self.delete_snapshot(c, filesystemName, snapshotName, filesetName)
Example #4
def _init_env():
    net_map = create_topology()
    # create the environment object
    destination = list(range(40, 50))
    env = ENV(node_list, net_map, NODE_CPT_SCALE, CPT_SWING_RANGE, delta,
              destination)
    # create the task pools
    task_list = create_taskpool(TASK_NUM=200)
    test_task_list = create_taskpool(TASK_NUM=50)

    return net_map, env, task_list, test_task_list
Example #5
    def test_simple_fileset(self):
        filesystemName = ENV.get_random_filesystem()
        filesetName = "fileset_" + get_random_string()
        c = RestClient()
        self.create_fileset(c, filesystemName, filesetName, inodeSpace='root')

        r = self.list_filesets(c, filesystemName, filesetName)
        res = json.loads(r.text)
        names = [str(s['config']['filesetName']) for s in res['filesets']]
        self.assertTrue(filesetName in names)

        r = self.list_filesets(c, filesystemName)
        res = json.loads(r.text)
        names = [str(s['config']['filesetName']) for s in res['filesets']]
        self.assertTrue(filesetName in names)

        self.delete_fileset(c, filesystemName, filesetName)
Example #6
def main():
    args = parse_args()
    for key in vars(args).keys():
        print('[*] {} = {}'.format(key, vars(args)[key]))

    save_dir = make_saving_dir(args)
    print(save_dir)
    result = np.zeros((args.n_sample, args.n_trial, args.max_ep, 2))

    for sample in range(args.n_sample):
        env = ENV(mapFile=args.map_name, random=args.random)
        model = {'MBIE': mbie.MBIE(env, args.beta),
                 'MBIE_NS': mbie.MBIE_NS(env, args.beta),
                 'DH': hindsight.DH(env, bool(args.ent_known), args.beta, args.lambd),
                 'DO': outcome.DO(env, bool(args.ent_known), args.beta, args.lambd)}
        print('sample {} out of {}'.format(sample, args.n_sample))
        env._render()

        np.save(save_dir + "map_sample_{}.npy".format(sample), env.map)
        for trial in range(args.n_trial):
            print('trial = {}'.format(trial))
            mrl = model[args.method]
            mrl.reset()
            for episode in range(args.max_ep):
                terminal = False
                step = 0
                R = []
                s = env.reset()
                while not terminal and step < args.max_step:
                    action = np.random.choice(np.flatnonzero(mrl.Q[s, :] == mrl.Q[s, :].max()))
                    ns, r, terminal = env.step(action)
                    R.append(r)
                    mrl.observe(s, action, ns, r, terminal)
                    step += 1
                    s = ns

                result[sample, trial, episode, 0] = step
                result[sample, trial, episode, 1] = disc_return(R, mrl.gamma)
                mrl.Qupdate()
                print(episode, step, disc_return(R, mrl.gamma), np.max(mrl.Q))
                #print(np.max(mrl.Q, axis=1).reshape(13,13))
            try:
                np.save(save_dir + "entopy_trail_{}_sample_{}.npy".format(trial, sample), mrl.entropy)
            except AttributeError:
                print("No entropy to save")
            np.save(save_dir + "count_trail_{}_sample_{}.npy".format(trial, sample), mrl.count)
    np.save(save_dir + 'results.npy', result)
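Two helpers used above deserve a note: the action selection breaks Q-value ties uniformly at random instead of always taking the first argmax, and disc_return is presumably a plain discounted sum. A sketch of both; this disc_return is a plausible reconstruction, not the original definition:

import numpy as np

def disc_return(rewards, gamma):
    # discounted return: sum over t of gamma**t * rewards[t] (assumed semantics)
    return sum(r * gamma ** t for t, r in enumerate(rewards))

# greedy action with uniform random tie-breaking, as in the inner loop above
q_row = np.array([0.0, 1.0, 1.0, 0.5])
action = np.random.choice(np.flatnonzero(q_row == q_row.max()))  # yields 1 or 2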
Example #7
    state_dim = model_config.getint('model', 'state_dim')
    gamma = model_config.getfloat('model', 'gamma')
    bxmin = env_config.getfloat('sim', 'xmin')
    bxmax = env_config.getfloat('sim', 'xmax')
    bymin = env_config.getfloat('sim', 'ymin')
    bymax = env_config.getfloat('sim', 'ymax')
    xmin = env_config.getfloat('visualization', 'xmin')
    xmax = env_config.getfloat('visualization', 'xmax')
    ymin = env_config.getfloat('visualization', 'ymin')
    ymax = env_config.getfloat('visualization', 'ymax')
    crossing_radius = env_config.getfloat('sim', 'crossing_radius')
    kinematic = env_config.getboolean('agent', 'kinematic')
    radius = env_config.getfloat('agent', 'radius')

    device = torch.device('cpu')
    test_env = ENV(config=env_config, phase='test')
    test_env.reset(case)
    model = ValueNetwork(state_dim=state_dim,
                         fc_layers=[100, 100, 100],
                         kinematic=kinematic)
    model.load_state_dict(
        torch.load(weight_path, map_location=lambda storage, loc: storage))
    _, state_sequences, _, _ = run_one_episode(model, 'test', test_env, gamma,
                                               None, kinematic, device)

    positions = list()
    colors = list()
    counter = list()
    line_positions = list()
    for i in range(len(state_sequences[0])):
        counter.append(i)
Example #8
from gevent import monkey

monkey.patch_all()
from gevent.wsgi import WSGIServer
from env import ENV

cli = ENV.parse_cli(['env=', 'port='])

from flask import Flask, abort, request, jsonify, make_response, url_for, g, redirect, Response, current_app

flask_app = Flask(__name__)


@flask_app.route('/tasks', methods=["GET"])
def route_get_tasks():
    pass


@flask_app.route('/tasks', methods=["PUT"])
def route_put_tasks():
    pass


@flask_app.route('/tasks', methods=["POST"])
def route_post_tasks():
    pass


@flask_app.route('/tasks', methods=["DELETE"])
def route_delete_tasks():
    pass

if __name__ == '__main__':
    http_server = WSGIServer(('', int(cli['--port'])), flask_app)
    print('SERVING!')
    http_server.serve_forever()
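Note that, to the best of my knowledge, gevent removed the gevent.wsgi module in the 1.3 release; on a modern gevent the drop-in equivalent for the import above is:

from gevent.pywsgi import WSGIServer  # replaces gevent.wsgi on gevent >= 1.3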
Example #10
def train(etd_factor, segment=False):
    env = ENV()
    # initialize the environment
    net_map, task_list, test_task_list = env.net_map, env.task_list, env.test_task_list
    neighbors_list = env.neighbor_list
    # NET_STATES = np.array([random.randint(1, 3) for _ in node_list])
    # build the list of neighbor nodes for each node
    # neighbors_list = []
    # for k in node_list:
    #     tmp_ = []
    #     for n in node_list:
    #         if net_map[k, n] != 0:
    #             tmp_.append(n)
    #     neighbors_list.append(tmp_)
    # create the agent
    # agent = Agent(n_actions=n_action, n_features=n_features)
    agent = DDPG(2, n_features, DPG_bounds, env)
    # task

    # evaluation records
    evaluation_his = []
    x = []
    with open('record.txt', 'w+') as fp:
        fp.write(
            'Iteration\t\tCost\t\tCounter\t\tDelay\t\tEnergyConsumption\n')
        fp.write('-' * 50)
        fp.write('\n')

    time_counter = 1
    netUpdateFlag = False
    step = 0
    for i in range(iterations):
        task_index = np.random.randint(0, len(task_list))
        task = task_list[task_index]
        step_counter = 0
        observation = _init_observation(task, env)
        des_node = task[3]['des_node']

        tmp_path = env.path_list[des_node - 40]

        ec = 0
        delay = 0
        # Tabu = []
        while True:
            present_node = one_hot_decode(observation[4:4 + NODE_NUM])
            if time_counter % change_rounds == 0:
                netUpdateFlag = True

            if segment:
                if observation[0] > 0:
                    action = agent.choose_action(observation)
                else:
                    action = tmp_path[present_node]
            else:
                # determine the valid neighbor nodes of this node
                action = agent.choose_action(
                    observation)  # TODO(Wezi): check the action dimension
            result, ec_, delay_ = env.perceive(
                observation, action, etd_factor,
                netUpdateFlag)  # result = [r,s']
            netUpdateFlag = False
            ec += ec_
            delay += delay_
            agent.store_transition(observation, action, result[0], result[1])
            # print(result[0])
            # if action <= max(node_list):
            #     step_counter += 1
            time_counter += 1
            step += 1
            step_counter += 1
            observation = result[1]
            if step > 200 and (step % 10 == 0):
                # DQN learning step
                agent.learn()

            if one_hot_decode(observation[4:4 + NODE_NUM]) == des_node:
                break

            # save the model
            if i % 200 == 0 and i >= 400:
                agent.saveModel(i)

        if i >= 300 and i % 100 == 0:
            res_cost, res_counter, latency, res_ec = evaluation(agent,
                                                                env,
                                                                test_task_list,
                                                                neighbors_list,
                                                                change_rounds,
                                                                segment=False)
            with open('record.txt', 'a+') as fp:
                fp.write('%d\t\t%f\t\t%f\t\t%f\t\t%f\n' %
                         (i, res_cost, res_counter, latency, res_ec))
            evaluation_his.append([res_cost, res_counter, latency, res_ec])
            x.append(i)

        # if i > 500 and (i % 1000 == 0):
        #     print(agent_list[0].DQN.fetch_eval(np.array(initial_observation)))

        print("the %d time cost %d rounds!" % (i + 1, step_counter + 1),
              end='')
        print("the ec:\t%f\tthe delay:\t%f" % (ec, delay))
    # record evaluation results
    cost_his = [each[0] for each in evaluation_his]
    counter_his = [each[1] for each in evaluation_his]
    latency_his = [each[2] for each in evaluation_his]
    ec_his = [each[3] for each in evaluation_his]

    fig = plt.figure()
    ax1 = fig.add_subplot(221)
    ax1.plot(x, cost_his)
    ax1.set_title('cost_his')
    ax2 = fig.add_subplot(222)
    ax2.plot(x, counter_his)
    ax2.set_title('round_his')
    ax3 = fig.add_subplot(223)
    ax3.plot(x, latency_his)
    ax3.set_title('latency_his')
    ax4 = fig.add_subplot(224)
    ax4.plot(x, ec_his)
    ax4.set_title('ec_his')
    plt.show()
Example #11
def test_step_with_kinematic():
    env_config = configparser.RawConfigParser()
    env_config.read('configs/test_env.config')
    env_config.set('agent', 'kinematic', 'true')
    test_env = ENV(env_config, phase='test')
    test_env.reset()

    # test state computation
    states, rewards, done_signals = test_env.step((Action(1, 0), Action(1, 0)))
    assert np.allclose(
        states[0], JointState(-1, 0, 1, 0, 0.3, 2, 0, 1.0, 0, 1, 0, -1, 0,
                              0.3))
    assert np.allclose(
        states[1],
        JointState(1, 0, -1, 0, 0.3, -2, 0, 1.0, np.pi, -1, 0, 1, 0, 0.3))
    assert rewards == [0, 0]
    assert done_signals == [False, False]

    # test one-step lookahead
    reward, end_time = test_env.compute_reward(0, [Action(1.5, 0), None])
    assert reward == -0.25
    assert end_time == 1

    reward, end_time = test_env.compute_reward(
        0, [Action(1.5, 0), Action(1.5, 0)])
    assert reward == -0.25
    assert end_time == 0.5

    # test collision detection
    states, rewards, done_signals = test_env.step((Action(1, 0), Action(1, 0)))
    assert np.allclose(
        states[0], JointState(0, 0, 1, 0, 0.3, 2, 0, 1.0, 0, 0, 0, -1, 0, 0.3))
    assert np.allclose(
        states[1],
        JointState(0, 0, -1, 0, 0.3, -2, 0, 1.0, np.pi, 0, 0, 1, 0, 0.3))
    assert rewards == [-0.25, -0.25]
    assert done_signals == [2, 2]

    # test reaching goal
    test_env = ENV(env_config, phase='test')
    test_env.reset()
    test_env.step((Action(1, np.pi / 2), Action(2, np.pi / 2)))
    test_env.step((Action(4, -np.pi / 2), Action(4, -np.pi / 2)))
    states, rewards, done_signals = test_env.step(
        (Action(1, -np.pi / 2), Action(2, -np.pi / 2)))
    assert rewards == [1, 1]
    assert done_signals == [1, 1]
Example #12
def p_loop(EPISODE, GAMMA, LAMBDA, ALPHA, path):
    """
    Training loop.
    """
    # initialize w
    try:
        w = np.load(path)
        print("Load {}".format(path))
        print("-" * 30)
    except FileNotFoundError:
        w = np.zeros((12 * 12 * 12 * 12 * 4, 1))
        print("Initialize Value")
        print("-" * 30)

    # initialize the Feature_Encoder & Actor
    encoder = FEATURE_ENCODER(ACTION)
    actor = ACTOR(encoder, ACTION, is_train=True)

    # initialize training parameters
    step_a = INTERVAL_A / INTERVAL_ENV

    # main update loop
    for ep in range(EPISODE):
        # training log records
        w_hist = []
        r_hist = []

        # initialize the eligibility trace
        et = np.zeros_like(w)

        # randomly initialize the environment and state
        e = ENV()

        # generate the first action a_t, encode features s_t, update state
        a = actor.act([e.c.dx, e.c.dy, e.c.vx, e.c.vy], w)
        en = encoder.encode([e.c.dx, e.c.dy, e.c.vx, e.c.vy], a)
        e.update(a)

        for t in range(int(T / INTERVAL_ENV)):
            # action simulation
            if t % step_a == 0:
                # update action a_{t+1}
                a_new = actor.act([e.c.dx, e.c.dy, e.c.vx, e.c.vy], w)
                # update features s_{t+1}
                en_new = encoder.encode([e.c.dx, e.c.dy, e.c.vx, e.c.vy],
                                        a_new)
                # compute delta (TD error)
                delta = e.r + GAMMA * np.matmul(en_new.T, w) - np.matmul(
                    en.T, w)
                # update the eligibility trace
                et = GAMMA * LAMBDA * et + en
                # update the weight matrix w
                w += ALPHA * delta * et

                a = a_new
                en = en_new

                # logging
                w_hist.append(np.sum(np.abs(delta)))
                r_hist.append(e.r)

            # state simulation
            e.update(a)

        # log output
        w_hist = np.array(w_hist)
        r_hist = np.array(r_hist)
        print(
            "EP{}:  delta_w:{:.2f}  total_r:{:.2f}  final_dist:{:.2f}  Vx:{:.2f}  Vy:{:.2f}"
            .format(ep + 1, np.sum(w_hist), np.sum(r_hist), -e.r, e.c.vx,
                    e.c.vy))

        # save the weight matrix w every 10 episodes
        if (ep + 1) % 10 == 0:
            np.save(path, w)
            print("Saved in {}".format(path))
            print("-" * 30)
Example #13
def main():
	original_size = (782, 600)
	env = ENV(actions, (original_size[0]/6, original_size[1]/6))
	gamma = 0.9
	epsilon = .95
	model_ph = 'models'
	if not os.path.exists(model_ph):
		os.mkdir(model_ph)
	trials = 500
	trial_len = 1000

	dqn_agent = DQN(env=env)
	success_num = 0
	rewards = []
	q_values = []
	Q = []
	for trial in range(1, trials):
		t_reward = []
		t_qvalue = []
		cur_state = env.reset()
		for step in range(trial_len):
			action = dqn_agent.act(cur_state)
			new_state, reward, done, success = env.step(action)
			t_reward.append(reward)
			
			# reward = reward if not done else -20
			dqn_agent.remember(cur_state, action, reward, new_state, done)
	
			q_value = dqn_agent.replay()  # internally iterates default (prediction) model
			if q_value:
				t_qvalue.append(q_value)
				Q.append(q_value)
			else:
				t_qvalue.append(0.0)
				Q.append(0.0)
			dqn_agent.target_train()  # iterates target model
			cur_state = new_state

			dqn_agent.log_result()

			save_q(Q)

			if success:
				success_num += 1
				dqn_agent.step = 100
				print("Completed in {} trials".format(trial))
				dqn_agent.save_model(os.path.join(model_ph, "success-model.h5"))
				break
			if done:
				print("Failed to complete in trial {}, step {}".format(trial, step))
				dqn_agent.save_model(os.path.join(model_ph, "trial-{}-model.h5").format(trial))
				break
		rewards.append(np.sum(t_reward) if t_reward else 0.0)
		q_values.append(np.mean(t_qvalue) if t_qvalue else 0.0)
		
		with open('reward_and_Q/reward.txt', 'wb') as f:
			pickle.dump(rewards, f)
		with open('reward_and_Q/qvalue.txt', 'wb') as f:
			pickle.dump(q_values, f)
		print('trial: {}, success acc: {}'.format(trial, success_num / float(trial)))
Example #14
contract_filename = args.contract
if not contract_filename:
    # reconstructed opening branch (assumed), mirroring the template block below
    contract_path = os.path.join("./config/contracts",
                                 default_args['contract'] + ".json")
else:
    contract_path = os.path.join("./config/contracts",
                                 contract_filename + ".json")

with open(contract_path) as f:
    config = json.load(f)

template_filename = args.template
if not template_filename:
    template_path = os.path.join("./config/templates",
                                 default_args['template'] + ".txt")
else:
    template_path = os.path.join("./config/templates",
                                 template_filename + ".txt")

with open(template_path, 'r') as f:
    t = ENV.from_string(f.read())

with open('./config/stylesheet.css') as f:
    template_styles = f.read()

mainnet_url = args.mainnet_url
if not mainnet_url:
    mainnet_url = default_args['mainnet_url']

# wkhtmltopdf_path = "../../vr_nfts/wkhtmltopdf"
# pdf_config = Configuration(wkhtmltopdf=wkhtmltopdf_path)

w3 = Web3(Web3.HTTPProvider(mainnet_url))
Example #15
if __name__ == '__main__':
    NODE_NUM = 50
    n_features = NODE_NUM * 3 + 14
    n_actions = 2
    # etd_factor = 1
    etd_factor_list = [0.1, 0.5, 2, 10]
    iteration = 100
    Iter = 10
    batch = 128
    etd_factor = 10
    if os.path.exists("topo.pk"):
        with open("topo.pk", 'rb') as f:
            env = pickle.load(f)
    else:
        env = ENV()
        with open("topo.pk", "wb") as f:
            pickle.dump(env, f)

    resBuffer = {'reward': [], 'ec': [], 'delay': []}
    model = DDPG(n_features, n_actions, env, etd_factor)
    loss = {}
    for i in range(Iter):
        history = model.train(iteration, batch, Iter, i)
        loss[str(i)] = history['Loss']
        model.evaluate(resBuffer)
    loss_his = pd.DataFrame(loss)
    loss_his.to_csv('ddpg_loss_etd_1.csv')
    res = pd.DataFrame(resBuffer)
    res.to_csv("ddpg_etd_1.csv")
    ax = res.plot(grid=True)
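The topo.pk block above is a cache-on-first-use pattern: build the (random) topology once, pickle it, and reuse the same environment across runs so results stay comparable. A generic restatement of the same idiom:

import os
import pickle

def load_or_build(path, build):
    # unpickle a cached object if present, else build it and cache it
    if os.path.exists(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
    obj = build()
    with open(path, 'wb') as f:
        pickle.dump(obj, f)
    return obj

# e.g. env = load_or_build("topo.pk", ENV)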
Example #16
def main(EP, VIS, path, FAST):
    # initialize simulation parameters
    step_a = INTERVAL_A / INTERVAL_ENV
    # initialize the feature encoder & action generator
    encoder = FEATURE_ENCODER(ACTION)
    actor = ACTOR(encoder, ACTION, is_train=False)
    # load the weight matrix
    try:
        w = np.load(path)
        print("Load {}".format(path))
        print("-" * 30)
    except FileNotFoundError:
        print("Could not find {}".format(path))
        return 0
    # set up real-time visualization
    if VIS:
        plt.ion()
        plt.figure(figsize=(5, 5))
        plt.axis([0, 100, 0, 100])

    for ep in range(EP):
        sys.stdout.write("EP:{} ".format(ep + 1))
        # initialize the environment
        # e = ENV(w=100, h=100, target=[85.0, 85.0], c_x=10.0, c_y=10.0, c_vx=0.0, c_vy=0.0)
        # e = ENV(w=100, h=100, c_vx=0.0, c_vy=0.0)
        e = ENV(w=100, h=100)

        # visualization
        if VIS:
            plt.scatter(e.target[0], e.target[1], s=30, c='red')
        else:
            track_x = []
            track_y = []

        for t in range(int(T / INTERVAL_ENV)):

            if t % step_a == 0:
                a = actor.act([e.c.dx, e.c.dy, e.c.vx, e.c.vy], w)

            e.update(a)

            # visualization
            if VIS and t % FAST == 0:
                sys.stdout.write(
                    "Ep:{}-{}  Vx:{:.2f}  Vy:{:.2f}  Action:{}        \r".
                    format(ep, t + 1, e.c.vx, e.c.vy, a))
                sys.stdout.flush()

                plt.scatter(e.c.x, e.c.y, s=10, c='blue', alpha=0.2)
                plt.scatter(e.target[0], e.target[1], s=30, c='red')
                plt.pause(0.01)
            elif not VIS:
                track_x.append(e.c.x)
                track_y.append(e.c.y)
                str_out = "processing"
                if (t + 1) % 300 == 0:
                    sys.stdout.write(str_out[(t + 1) // 300 - 1])
                    sys.stdout.flush()

        print(
            "  Final_distance:{:.2f}                            ".format(-e.r))

        if VIS:
            plt.scatter(e.c.x, e.c.y, s=30, c='orange')
            plt.text(e.c.x, e.c.y - 1, "EP{} Dist:{:.2f}".format(ep + 1, -e.r))
            plt.pause(5)
        if not VIS:
            plt.scatter(track_x, track_y, s=5, c='blue', alpha=0.2)
            plt.scatter(e.target[0], e.target[1], s=30, c='red')
            plt.scatter(track_x[-1], track_y[-1], s=30, c='orange')
            plt.text(e.c.x, e.c.y - 1, "Dist:{:.2f}".format(-e.r))
            plt.axis([0, 100, 0, 100])
            plt.show()