示例#1
0
def train(nameIndx):
    """Train a DDPG agent on ``Test(nameIndx)`` and save a checkpoint.

    Runs up to ``MAX_EPISODES`` episodes of at most ``MAX_EP_STEPS`` steps. An
    episode now ends early when the environment reports ``done``. Transitions
    are stored with the reward scaled by 1/10.

    Once more than ``MEMORY_CAPACITY`` transitions have been collected, the
    agent learns on every step. It is passed the global step counter every 100
    steps and ``False`` otherwise.

    Training stops early when the moving reward ``MU_REWARD`` exceeds
    ``GOAL_REWARD``. The agent's directory ``agent.path`` is then wiped and
    recreated, and the session is saved to ``DDPG.ckpt`` inside it.

    Args:
        nameIndx: index passed to ``Test`` and used to look up the agent name in
            ``SIDE`` (the original comment notes ``0 = right``).
    """
    from collections import deque

    # Rewards of the most recent 100 episodes (oldest dropped automatically).
    T_REWARD = deque(maxlen=100)
    MU_REWARD = 0
    BEST_R = -1000
    env = Test(nameIndx)  # 0 = right

    agent = DDPG(act_dim=8, obs_dim=42,
                 lr_actor=0.0001, lr_q_value=0.001, gamma=0.99, tau=0.01,
                 action_noise_std=1, name=SIDE[nameIndx])

    cnt = 0  # global step counter across all episodes
    t1 = time.time()
    for i in range(MAX_EPISODES):
        # t23: time spent inside env.step; t32: time between steps
        # (action selection + learning).
        t2, t3, t23, t32 = 0., 0., 0., 0.
        s = env.reset()
        ep_reward = 0
        for j in range(MAX_EP_STEPS):
            a = agent.choose_action(s)
            t2 = time.time()
            if t3 != 0:
                t32 += t2 - t3
            s_, r, done, info = env.step(a)
            t3 = time.time()
            t23 += t3 - t2
            agent.memory.store_transition(s, a, r / 10, s_, done)

            if cnt > MEMORY_CAPACITY:
                # Every 100th global step the step count is passed to learn();
                # otherwise learn(False) is called.
                agent.learn(cnt if cnt % 100 == 0 else False)

            s = s_
            ep_reward += r
            cnt += 1
            # Stop stepping once the episode has terminated.
            if done:
                break

        T_REWARD.append(ep_reward)
        # Always divides by the full window size (100), so the average is
        # deliberately low until 100 episodes have been played.
        MU_REWARD = sum(T_REWARD) / 100
        BEST_R = max(BEST_R, MU_REWARD)
        print('Episode:', i, ' Reward: %i' % int(ep_reward), 'MU_REWARD: ', int(MU_REWARD),'BEST_R: ', int(BEST_R), 'cnt = ',j , 't_step:', int(t23), 't_learn: ', int(t32))
        if MU_REWARD > GOAL_REWARD:
            break

    # Replace any previous checkpoint directory with a fresh one.
    if os.path.isdir(agent.path):
        shutil.rmtree(agent.path)
    os.mkdir(agent.path)
    ckpt_path = os.path.join(agent.path, 'DDPG.ckpt')
    save_path = agent.saver.save(agent.sess, ckpt_path, write_meta_graph=False)
    print("\nSave Model %s\n" % save_path)
    print('Running time: ', time.time() - t1)
sess = tf.Session()
tf.set_random_seed(1)
plt.style.use('seaborn')

# Train one multi-agent DDPG learner per candidate discount factor.
for gamma in Gamma:
    agents = DDPG(
        n_agents=N_AGENTS,
        gamma=gamma,
        memory_size=MEMORY_SIZE,
        train=True,
    )

    # Careful: reset() hands back the state together with its info record.
    s, info = env.reset()
    rs = []
    epi_r = 0
    for i in range(10000):
        a = agents.choose_action(s)

        # Careful: info_ is a list [total_flow, normal_flow] describing s_.
        s_, r, info_ = env.step(a)

        # Careful: info belongs to s and info_ belongs to s_, so the stored
        # transition carries info (not info_).
        agents.store_transition(s, a, r, s_, info)
        rs.append(r)
        # Exponential moving average of the per-step reward.
        epi_r = 0.1 * r + 0.9 * epi_r

        # Learn after a 64-step warm-up; the last 1000 steps are held out
        # for testing.
        if 64 < i < 10000 - 1000:
            agents.learn()

        # Careful: advance the state and its info record together.
        s, info = s_, info_
示例#3
0
# Seed every RNG *before* the agent is built. tf.set_random_seed only
# affects ops created after it is called, so seeding after DDPG() had
# already created its graph did not make the agent's initialisation
# reproducible.
np.random.seed(1)
env.seed(1)
tf.set_random_seed(1)
agent = DDPG(a_dim=6, s_dim=17, a_bound=1, lr_a=0.0001, lr_c=0.001, seed=1)

exploration_rate = 0.2  # probability of taking a uniformly random action
var = 0.1               # std of the Gaussian noise added to policy actions

total_reward = []
for episode in range(1000):
    state = env.reset()
    cum_reward = 0
    for step in range(1000):
        if np.random.uniform() > exploration_rate:
            # Noisy policy action, clipped to the [-1, 1] action range.
            action = np.clip(
                np.random.normal(np.reshape(agent.choose_action(state), [6]),
                                 var), -1, 1)
        else:
            action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        cum_reward += reward
        agent.store_transition(state, action, reward, next_state, done)
        state = next_state
        agent.learn()
        if done:
            print('Episode', episode, ' Complete at reward ', cum_reward,
                  '!!!')
            # The episode is over; stepping a finished env is undefined.
            break
    total_reward.append(cum_reward)
示例#4
0
文件: run_ddpg.py 项目: shilx001/HRL
import gym
from ddpg import DDPG
import matplotlib.pyplot as plt
import numpy as np  # np.clip / np.random.normal are used below
import pickle

env = gym.make('MountainCarContinuous-v0')

agent = DDPG(a_dim=1, s_dim=2, a_bound=1)

total_reward = []
# Exploration noise std. Initialised once, outside the episode loop, so the
# per-episode decay at the bottom actually accumulates (it used to be reset
# to 1 at the start of every episode, making the decay a no-op).
var = 1
for episode in range(300):
    state = env.reset()
    cum_reward = 0
    for step in range(1000):
        action = np.clip(np.random.normal(agent.choose_action(state), var), -1,
                         1)
        next_state, reward, done, _ = env.step([action])
        cum_reward += reward
        agent.store_transition(state, action, reward, next_state)
        # NOTE(review): agent.learn() is never called in this loop, so the
        # policy is only updated if store_transition trains internally —
        # confirm against ddpg.DDPG.
        state = next_state
        if done:
            print('Episode', episode, ' Complete at reward ', cum_reward,
                  '!!!')
            break
        if step == 1000 - 1:
            print('Episode', episode, ' finished at reward ', cum_reward)
    total_reward.append(cum_reward)
    # Linearly decay exploration down to a floor of ~0.1.
    if var > 0.1:
        var -= 0.01
示例#5
0
env = gym.make('MountainCarContinuous-v0')  # ('Hopper-v1')
print(env.action_space, env.observation_space)
print(env.action_space.low, env.action_space.high)
# Read the sizes from the env instead of hard-coding 1 and 2, so switching to
# another continuous-control env (e.g. the Hopper alternative above) works.
n_actions = env.action_space.shape[0]
n_states = env.observation_space.shape[0]

ddpg = DDPG(n_actions=n_actions, n_states=n_states, opt=config)

returns = []  # total (undiscounted) reward of every episode
for i in xrange(10000):
    ddpg.reset(0.1)
    state = env.reset()
    total_reward = 0.0
    for t in count():
        action = ddpg.choose_action(state)
        next_state, reward, done = ddpg.apply_action(env, action)
        ddpg.replay_memory.push(state=state,
                                action=action,
                                next_state=next_state,
                                terminate=done,
                                reward=reward)
        total_reward += reward

        state = next_state
        if done:
            break

        # Start updating once the replay memory holds more than 100 samples.
        if len(ddpg.replay_memory) > 100:
            ddpg.update()
    # Previously `returns` was created but never filled.
    returns.append(total_reward)
示例#6
0
a_dim = env.action_space.shape[0]                       # action-space dimensionality
a_bound = env.action_space.low, env.action_space.high   # (low, high) action bounds

ddpg = DDPG(s_dim, a_dim, a_bound,
            MEMORY_CAPACITY, BATCH_SIZE,
            GAMMA, ALPHA_A, ALPHA_C, TAO)
# Load previously trained weights; this loop only evaluates the policy.
ddpg.initail_net('./result.ckpt')

for each_episode in range(MAX_EPISODES):
    ep_reward = 0
    s = env.reset()
    for each_step in range(MAX_EP_STEPS):
        if RENDER:
            env.render()

        # The network expects a batch, so add a leading axis and take row 0.
        a = ddpg.choose_action(s[np.newaxis, :])[0]
        print(a)

        s_, r, done, _ = env.step(a)

        s = s_
        ep_reward += r

        # End the episode when the env terminates (it used to keep stepping a
        # finished env) or when the step budget is exhausted.
        if done or each_step == MAX_EP_STEPS - 1:
            print('Episode:', each_episode, ' Reward: %i' % int(ep_reward))
            break
示例#7
0
class SmartAgent(object):
    """Agent that mixes scripted unit selection with DDPG-chosen actions.

    The observation interface it reads (``obs.reward``,
    ``obs.observation['feature_units']``,
    ``obs.observation['available_actions']``) and the ``actions.FunctionCall``
    results it returns look like the PySC2 agent API. NOTE(review): confirm.

    On each step the agent does one of three things:
      * plays a scripted opening attack until a unit is close to an enemy;
      * selects the unit it wants to control;
      * asks the DDPG model for one of ``smart_actions``.
    """

    def __init__(self):
        # Bookkeeping mirrored from the original base.agent.
        self.reward = 0
        self.episodes = 0
        self.steps = 0
        self.obs_spec = None
        self.action_spec = None

        # extract_features() builds a state of length
        #   enemy hp (E) + player hp (P) + enemy xy (2E) + player xy (2P)
        #   + per-player distance (P) = 3E + 4P.
        # The padding ensures this whenever counts stay within the defaults.
        # Deriving it keeps s_dim in sync with the feature layout; it used to
        # be a hand-maintained 11.
        self.ddpg = DDPG(
            a_dim=len(smart_actions),
            s_dim=3 * DEFAULT_ENEMY_COUNT + 4 * DEFAULT_PLAYER_COUNT,
        )

        # Self-defined state.
        self.fighting = False
        self.player_hp = []          # summed player hp per step
        self.enemy_hp = []           # summed enemy hp per step
        self.previous_enemy_hp = []
        self.previous_player_hp = []
        self.leftover_enemy_hp = []  # summed enemy hp at the end of each episode
        self.win = 0
        self.count = 0               # unit-steps observed with hp < 20

        self.previous_action = None
        self.previous_state = None

    def step(self, obs):
        """Return the next ``actions.FunctionCall`` for observation *obs*."""
        # From the original base.agent.
        self.steps += 1
        self.reward += obs.reward

        (current_state, enemy_hp, player_hp, enemy_loc, player_loc, distance,
         selected, enemy_count, player_count) = self.extract_features(obs)

        self.player_hp.append(sum(player_hp))
        self.enemy_hp.append(sum(enemy_hp))

        # Scripted opening to speed up learning: keep attacking the first
        # enemy until any unit gets within 20 of an enemy. (The original
        # `while` always returned on its first pass, so it is an `if`.)
        if not self.fighting:
            for i in range(0, player_count):
                if distance[i] < 20:
                    self.fighting = True
            return actions.FunctionCall(_ATTACK_SCREEN,
                                        [_NOT_QUEUED, enemy_loc[0]])

        # Unit selection: prefer the unit(s) closest to an enemy, break ties
        # by lowest hp, then pick randomly among the remaining candidates.
        closest = min(distance)
        closest_indices = [
            i for i in range(player_count) if distance[i] == closest
        ]

        # Only live units take part in the hp tie-break. player_hp is padded
        # with 0 for missing units, which used to always win the min() and
        # leave no live unit matching it.
        live_hp = player_hp[:player_count]
        lowest_hp = min(live_hp) if live_hp else None
        lowest_hp_indices = [
            i for i in range(player_count) if player_hp[i] == lowest_hp
        ]

        common_indices = list(
            set(closest_indices).intersection(lowest_hp_indices))

        if len(common_indices) != 0:
            selected_index = random.choice(common_indices)
        elif len(closest_indices) != 0:
            selected_index = random.choice(closest_indices)
        else:
            selected_index = 0

        if selected[selected_index] == 0 or (selected[0] == 1
                                             and selected[1] == 1):
            return actions.FunctionCall(
                _SELECT_POINT, [_NOT_QUEUED, player_loc[selected_index]])

        rl_action = self.ddpg.choose_action(np.array(current_state))
        smart_action = smart_actions[rl_action]

        # Store the previous transition now that its outcome is observable.
        if self.previous_action is not None:
            reward = self.get_reward(obs, distance, player_hp, enemy_hp,
                                     player_count, enemy_count, selected,
                                     player_loc, enemy_loc)

            self.ddpg.store_transition(np.array(self.previous_state),
                                       self.previous_action, reward,
                                       np.array(current_state))

        self.previous_state = current_state
        self.previous_action = rl_action
        self.previous_enemy_hp = enemy_hp
        self.previous_player_hp = player_hp

        return self.perform_action(obs, smart_action, player_loc, enemy_loc,
                                   selected, player_count, enemy_count,
                                   distance, player_hp)

    def get_reward(self, obs, distance, player_hp, enemy_hp, player_count,
                   enemy_count, selected, unit_locs, enemy_locs):
        """Distance-shaped reward for the currently selected unit.

        Returns -1.0 when the selected unit's enemy distance is < 6 or > 20.
        Otherwise it returns ``distance / 20``, which rewards staying close to
        20. If several units are selected, the last one wins. If none is
        selected, index -1 (the last unit slot) is used.
        """
        reward = 0.

        selected_index = -1
        for i in range(0, DEFAULT_PLAYER_COUNT):
            if selected[i] == 1:
                selected_index = i

        d = distance[selected_index]
        if d < 6 or d > 20:
            reward -= 1
        else:
            reward = d / 20
        return reward

    def extract_features(self, obs):
        """Build the fixed-length state vector and helper lists from *obs*.

        Returns ``(current_state, enemy_hp, player_hp, enemy, player,
        min_distance, is_selected, enemy_unit_count, player_unit_count)``.
        Per-unit lists are padded up to the DEFAULT_*_COUNT sizes:
        coordinates with ``(-1, -1)``, hp with 0, selection with -1, and
        distance with 100000.
        """
        var = obs.observation['feature_units']
        enemy, player = [], []        # (x, y) per unit
        enemy_hp, player_hp = [], []  # enemy: health + shield, player: health
        is_selected = []
        enemy_unit_count, player_unit_count = 0, 0

        for i in range(0, var.shape[0]):
            if var[i][_UNIT_ALLIANCE] == _PLAYER_HOSTILE:
                enemy.append((var[i][_UNIT_X], var[i][_UNIT_Y]))
                enemy_hp.append(var[i][_UNIT_HEALTH] + var[i][_UNIT_SHIELD])
                enemy_unit_count += 1
            else:
                player.append((var[i][_UNIT_X], var[i][_UNIT_Y]))
                player_hp.append(var[i][_UNIT_HEALTH])
                is_selected.append(var[i][_UNIT_IS_SELECTED])

                if var[i][_UNIT_HEALTH] < 20:
                    self.count += 1

                player_unit_count += 1

        # Pad so the state vector keeps a fixed length.
        for i in range(player_unit_count, DEFAULT_PLAYER_COUNT):
            player.append((-1, -1))
            player_hp.append(0)
            is_selected.append(-1)

        for i in range(enemy_unit_count, DEFAULT_ENEMY_COUNT):
            enemy.append((-1, -1))
            enemy_hp.append(0)

        # Integer (truncated) Euclidean distance from each player to the
        # nearest enemy.
        min_distance = [100000] * DEFAULT_PLAYER_COUNT
        for i in range(0, player_unit_count):
            for j in range(0, enemy_unit_count):
                distance = int(
                    math.sqrt((player[i][0] - enemy[j][0])**2 +
                              (player[i][1] - enemy[j][1])**2))
                if distance < min_distance[i]:
                    min_distance[i] = distance

        # Flatten every feature to 1-D and concatenate them.
        current_state = np.hstack((
            np.array(enemy_hp).flatten(),
            np.array(player_hp).flatten(),
            np.array(enemy).flatten(),
            np.array(player).flatten(),
            np.array(min_distance).flatten(),
        ))

        return current_state, enemy_hp, player_hp, enemy, player, min_distance, is_selected, enemy_unit_count, player_unit_count

    @staticmethod
    def _clamp_to_screen(x, y):
        """Clamp a screen coordinate to the usable area [3, 79] x [3, 59]."""
        return min(max(x, 3), 79), min(max(y, 3), 59)

    def perform_action(self, obs, action, unit_locs, enemy_locs, selected,
                       player_count, enemy_count, distance, player_hp):
        """Translate a DDPG *action* into an ``actions.FunctionCall``.

        ATTACK_TARGET attacks the first enemy. MOVE_* moves the selected unit
        4 screen units in that direction, clamped to the screen. If the chosen
        action cannot be issued, it falls back to a move onto the unit's own
        (unclamped) position.
        """
        index = -1
        for i in range(0, DEFAULT_PLAYER_COUNT):
            if selected[i] == 1:
                index = i

        x = unit_locs[index][0]
        y = unit_locs[index][1]

        if action == ATTACK_TARGET:
            if (_ATTACK_SCREEN in obs.observation["available_actions"]
                    and enemy_count >= 1):
                return actions.FunctionCall(
                    _ATTACK_SCREEN,
                    [_NOT_QUEUED, enemy_locs[0]])  # x,y => col,row
        else:
            # (action, dx, dy) for each move; screen y grows downwards.
            moves = (
                (MOVE_UP, 0, -4),
                (MOVE_DOWN, 0, 4),
                (MOVE_LEFT, -4, 0),
                (MOVE_RIGHT, 4, 0),
            )
            for move, dx, dy in moves:
                if action == move:
                    if (_MOVE_SCREEN in obs.observation["available_actions"]
                            and index != -1):
                        new_x, new_y = self._clamp_to_screen(x + dx, y + dy)
                        return actions.FunctionCall(
                            _MOVE_SCREEN,
                            [_NOT_QUEUED, [new_x, new_y]])  # x,y => col,row
                    break

        return actions.FunctionCall(_MOVE_SCREEN, [_NOT_QUEUED, [x, y]])

    # Not used in the current version.
    def get_disabled_actions(self, player_loc, selected):
        """Return indices in ``smart_actions`` that would be redundant now."""
        disabled_actions = []

        index = -1
        for i in range(0, DEFAULT_PLAYER_COUNT):
            if selected[i] == 1:
                index = i
                break

        x = player_loc[index][0]
        y = player_loc[index][1]

        # Don't attack again if the previous action already was an attack.
        if self.previous_action == smart_actions.index(ATTACK_TARGET):
            disabled_actions.append(smart_actions.index(ATTACK_TARGET))

        # Don't move towards a border the unit is already next to.
        if y <= 7:
            disabled_actions.append(smart_actions.index(MOVE_UP))
        if y >= 56:
            disabled_actions.append(smart_actions.index(MOVE_DOWN))
        if x <= 7:
            disabled_actions.append(smart_actions.index(MOVE_LEFT))
        if x >= 76:
            disabled_actions.append(smart_actions.index(MOVE_RIGHT))

        # Don't re-select if the previous action already tried to select.
        if self.previous_action == smart_actions.index(ACTION_SELECT_UNIT):
            disabled_actions.append(smart_actions.index(ACTION_SELECT_UNIT))

        return disabled_actions

    def plot_hp(self, path, save):
        """Plot hp curves (optionally saving them under *path*) and print stats.

        Averages are printed only when there is data. This avoids the
        ZeroDivisionError raised before any episode had finished.
        """
        plt.plot(np.arange(len(self.player_hp)), self.player_hp)
        plt.ylabel('player hp')
        plt.xlabel('training steps')
        if save:
            plt.savefig(path + '/player_hp.png')
        plt.close()

        plt.plot(np.arange(len(self.enemy_hp)), self.enemy_hp)
        plt.ylabel('enemy hp')
        plt.xlabel('training steps')
        if save:
            plt.savefig(path + '/enemy_hp.png')
        plt.close()

        plt.plot(np.arange(len(self.leftover_enemy_hp)),
                 self.leftover_enemy_hp)
        plt.ylabel('enemy hp')
        plt.xlabel('Episodes')
        if save:
            plt.savefig(path + '/eval.png')
        plt.close()

        if self.leftover_enemy_hp:
            print("AVG ENEMY HP LEFT",
                  sum(self.leftover_enemy_hp) / len(self.leftover_enemy_hp))
        if self.episodes > 1:
            # reset() scores an episode only from the 2nd call on, so
            # episodes - 1 results exist; 100.0 avoids Py2 integer division.
            print("Winning Rate: {0:.2f}%".format(
                100.0 * self.win / (self.episodes - 1)))
        print("Low hp controlled steps", self.count)

    # From the original base.agent.
    def setup(self, obs_spec, action_spec):
        self.obs_spec = obs_spec
        self.action_spec = action_spec

    # From the original base.agent.
    def reset(self):
        """Start a new episode: score the previous one, then run a DDPG update."""
        self.episodes += 1
        self.fighting = False
        if self.episodes > 1:
            self.leftover_enemy_hp.append(sum(self.previous_enemy_hp))
            if sum(self.previous_enemy_hp) == 0:
                self.win += 1
            self.ddpg.learn()
示例#8
0
def _to_binary_action(a, threshold=0.5):
    """Map the agent's continuous output *a* to the env's 0/1 action.

    Returns 1 when ``a >= threshold`` and 0 otherwise. Train and test both use
    this helper so they share one decision boundary; previously training used
    ``>=`` while testing used ``>``.
    """
    return 1 if a >= threshold else 0


def main(dic_agent_conf, dic_exp_conf, dic_env_conf, dic_path):
    """Train, test and record a DDPG agent on ``Env(dic_env_conf)``.

    Each step the agent (the "administrator") acts first via ``env.step``.
    Then ``env.step_user()`` advances the users and returns the next state,
    or ``None`` when the iteration is over.

    Args:
        dic_agent_conf: agent config; reads "NUMPY_SEED" and "TRAIN".
        dic_exp_conf: experiment config; reads "AGENT_NAME",
            "TRAIN_ITERATIONS" and "TEST_ITERATIONS".
        dic_env_conf: passed to ``Env`` and ``record``.
        dic_path: passed to ``DDPG``, ``plot`` and ``record``.

    Raises:
        ValueError: if ``AGENT_NAME`` is not "DDPG", the only agent this
            function can build. It used to fail later with a NameError.
    """
    np.random.seed(dic_agent_conf["NUMPY_SEED"])
    t = time.localtime(time.time())

    agent_name = dic_exp_conf["AGENT_NAME"]
    if agent_name != "DDPG":
        raise ValueError("unsupported AGENT_NAME: {!r}".format(agent_name))
    agent = DDPG(dic_agent_conf, dic_exp_conf, dic_path)
    print("=== Build Agent: %s ===" % agent_name)

    # ===== train =====
    print("=== Train Start ===")
    # NOTE(review): this list mixes per-step rewards with per-iteration sums
    # (both are appended below) before being plotted — confirm intent.
    train_reward = []

    env = Env(dic_env_conf)
    for cnt_train_iter in range(dic_exp_conf["TRAIN_ITERATIONS"]):
        s = env.reset()
        r_sum = 0
        cnt_train_step = 0
        while True:
            # The administrator (agent) acts first.
            a = agent.choose_action(s, explore=True)
            _, r = env.step(_to_binary_action(a))
            r_sum += r
            train_reward.append(r)

            # Then the users move; None marks the end of the iteration.
            s_ = env.step_user()
            if s_ is None:
                break

            agent.store_transition(s, a, r, s_)
            s = s_
            if agent.memory_batch_full:
                agent.learn()

            cnt_train_step += 1
            if cnt_train_step % 100 == 0:
                with open('result.txt', 'a+') as f:
                    f.write(
                        "train: iter:{}, step:{}, r_sum:{},rewrd:{},action:{},successAttackers:{},foundAttackers:{}\n"
                        .format(cnt_train_iter, cnt_train_step, r_sum, r, a,
                                len(env.successAttackers),
                                len(env.foundAttackers)))
                print(s)
                print(
                    "train, step:{}, r_sum:{},rewrd:{},action:{},successAttackers:{},foundAttackers:{}"
                    .format(cnt_train_step, r_sum, r, a,
                            len(env.successAttackers),
                            len(env.foundAttackers)))

        print("train, iter:{}, r_sum:{}".format(cnt_train_iter, r_sum))
        train_reward.append(r_sum)

    # NOTE(review): the model is saved only when "TRAIN" is falsy, which
    # looks inverted — confirm what the "TRAIN" flag means.
    if not dic_agent_conf["TRAIN"]:
        agent.save_model(t)
    print("=== Train End ===")

    # ==== test ====
    print("=== Test Start ===")
    test_reward = []

    for cnt_test_iter in range(dic_exp_conf["TEST_ITERATIONS"]):
        s = env.reset()
        r_sum = 0
        cnt_test_step = 0

        while True:
            # The administrator (agent) acts first, without exploration noise.
            a = agent.choose_action(s, explore=False)
            _, r = env.step(_to_binary_action(a))
            # Count the reward before the end check, as in training; the final
            # step's reward used to be dropped here.
            r_sum += r

            s_ = env.step_user()
            if s_ is None:
                break

            s = s_
            cnt_test_step += 1
            if cnt_test_step % 100 == 0:
                print(
                    "test, step:{}, r_sum:{},rewrd:{},action:{},successAttackers:{},foundAttackers:{}"
                    .format(cnt_test_step, r_sum, r, a,
                            len(env.successAttackers),
                            len(env.foundAttackers)))

        # Log the iteration total; this used to print the last step's r.
        print("test, iter:{}, r_sum:{}".format(cnt_test_iter, r_sum))
        test_reward.append(r_sum)

    print("=== Test End ===")
    # ==== record ====
    print("=== Record Begin ===")

    train_info = [train_reward]
    test_info = [test_reward]

    plot(train_info, dic_agent_conf, dic_exp_conf, dic_path, "TRAIN", t)
    plot(test_info, dic_agent_conf, dic_exp_conf, dic_path, "TEST", t)
    record(train_info, test_info, dic_agent_conf, dic_exp_conf, dic_env_conf,
           dic_path, t)
    print("=== Record End ===")
示例#9
0
class Strategy(object):
    def __init__(self, team):
        """Create the controller for one side of the match.

        Args:
            team: stored on the instance; ``ros_init`` treats 'A' and 'B' as
                the two robots and any other value as the coach.
        """
        self.sac_cal = sac_calculate()

        # Latest transition pieces.
        self.a = []
        self.s = []
        self.r = 0.0
        self.done = False

        self.avg_arr = np.zeros(64)
        self.team = team

        # World state refreshed by the ROS callbacks.
        self.RadHead2Ball = self.RadHead = 0.0
        self.NunbotAPosX = self.NunbotAPosY = self.NunbotBPosX = 0.0
        self.BallPosX = self.BallPosY = 0.0

        # Fixed field reference points.
        self.GoalX, self.GoalY = 900.0, 0.0
        self.StartX, self.StartY = -900.0, 0.0

        # Counters and per-episode bookkeeping.
        self.kick_count = self.kick_num = self.score = 0
        self.RadTurn = 0.0
        self.Steal = False
        self.dis2start = self.dis2goal = 0.0
        self.vec = VelCmd()
        self.A_info = np.array([1.0, 1.0, 1.0, 1.0, 0, 0, 0, 0])
        self.game_count = 2
        self.A_z = self.B_z = 0.0
        self.HowEnd = 0
        self.B_dis = 0.0
        self.ep_rwd = 0
        self.is_kick = False
        self.ready2restart = True
        self.list_rate = list(np.zeros(128))
        self.milestone = [0] * 11
        self.step_milestone = [0] * 11
        self.milestone_idx = 0

        # Episode-ending condition flags.
        self.is_in = self.is_out = False
        self.is_steal = self.is_fly = self.is_stealorfly = False
        self.real_resart = True
        self.step_count = 0

    def callback(self, data):  # Rostopic 之從外部得到的值
        self.RadHead2Ball = data.ballinfo.real_pos.angle
        self.RadHead = data.robotinfo[0].heading.theta
        self.BallPosX = data.ballinfo.pos.x
        self.BallPosY = data.ballinfo.pos.y
        self.NunbotAPosX = data.robotinfo[0].pos.x
        self.NunbotAPosY = data.robotinfo[0].pos.y
        self.NunbotBPosX = data.obstacleinfo.pos[0].x
        self.B_dis = data.obstacleinfo.polar_pos[0].radius

    def steal_callback(self, data):
        self.Steal = data.data

    def A_info_callback(self, data):
        self.A_info = np.array(data.data)
        self.is_kick = True

    def state_callback(self, data):
        self.kick_count = 0

    def reward_callback(self, data):
        pass
        # self.r = data.data
    def done_callback(self, data):
        self.done = data.data

    def fly_callback(self, data):
        self.A_z = data.pose[5].position.z
        self.B_z = data.pose[6].position.z

    def HowEnd_callback(self, data):
        self.HowEnd = data.data

    def ready2restart_callback(self, data):
        """Run ``self.restart()`` and then clear the ready-to-restart flag.

        The message payload is ignored.
        """
        self.restart()
        self.ready2restart = False

    def ros_init(self):
        """Set up the ROS node, topics and service proxies for ``self.team``.

        - 'A': builds the DDPG agent, starts node 'strategy_node_A', publishes
          nubot1 velocity commands and Gazebo model states, subscribes to
          nubot1 vision and Gazebo model states, and creates Shoot/BallHandle
          proxies (``call_Shoot``, ``call_Handle``, ``call_B_Handle``).
        - 'B': starts node 'strategy_node_B', publishes rival1 velocity
          commands and the steal flag, subscribes to rival1 vision, and creates
          rival1's BallHandle proxy.
        - anything else: starts the 'coach' node, which publishes
          state/reward/done/HowEnd topics and subscribes to vision, steal,
          A_info and ready2restart. It also creates the reset_simulation proxy
          (``call_restart``).
        """
        if self.team == 'A':
            # Learning rates are scaled by the module-level ``l_rate``.
            # NOTE(review): tau=0.995 is near-hard-copy if tau is the share of
            # the online net in the soft update, but typical if DDPG treats it
            # as a Polyak keep-factor — confirm against DDPG's definition.
            self.agent = DDPG(act_dim=2,
                              obs_dim=12,
                              lr_actor=l_rate * (1e-3),
                              lr_q_value=l_rate * (1e-3),
                              gamma=0.99,
                              tau=0.995,
                              action_noise_std=1)

            rospy.init_node('strategy_node_A', anonymous=True)
            # self.A_info_pub = rospy.Publisher('/nubot1/A_info', Float32MultiArray, queue_size=1) # 3in1
            self.vel_pub = rospy.Publisher('/nubot1/nubotcontrol/velcmd',
                                           VelCmd,
                                           queue_size=1)
            self.reset_pub = rospy.Publisher('/gazebo/set_model_state',
                                             ModelState,
                                             queue_size=10)
            # self.ready2restart_pub  = rospy.Publisher('nubot1/ready2restart',Bool, queue_size=1)
            rospy.Subscriber("/nubot1/omnivision/OmniVisionInfo",
                             OminiVisionInfo, self.callback)
            rospy.Subscriber('gazebo/model_states', ModelStates,
                             self.fly_callback)
            # rospy.Subscriber('/coach/state', String, self.state_callback)
            # rospy.Subscriber('/coach/reward', Float32, self.reward_callback)
            # rospy.Subscriber('/coach/done', Bool, self.done_callback)
            # rospy.Subscriber('coach/HowEnd', Int16, self.HowEnd_callback)
            # rospy.Subscriber("/rival1/steal", Bool, self.steal_callback)

            rospy.wait_for_service('/nubot1/Shoot')
            self.call_Shoot = rospy.ServiceProxy('/nubot1/Shoot', Shoot)

            # rospy.wait_for_service('/gazebo/reset_simulation')
            # self.call_restart = rospy.ServiceProxy('/gazebo/reset_simulation', Empty, persistent=True)

            # rospy.wait_for_service('/gazebo/set_model_state')
            # self.call_set_modol = rospy.ServiceProxy('/gazebo/set_model_state', SetModelState)
            # Ball-holding queries for both robots (used by steal()).
            rospy.wait_for_service('/nubot1/BallHandle')
            self.call_Handle = rospy.ServiceProxy('/nubot1/BallHandle',
                                                  BallHandle)
            rospy.wait_for_service('/rival1/BallHandle')
            self.call_B_Handle = rospy.ServiceProxy('/rival1/BallHandle',
                                                    BallHandle)
        elif self.team == 'B':
            rospy.init_node('strategy_node_B', anonymous=True)
            self.vel_pub = rospy.Publisher('/rival1/nubotcontrol/velcmd',
                                           VelCmd,
                                           queue_size=1)
            self.steal_pub = rospy.Publisher('/rival1/steal',
                                             Bool,
                                             queue_size=1)  # steal
            rospy.Subscriber("/rival1/omnivision/OmniVisionInfo",
                             OminiVisionInfo, self.callback)
            rospy.wait_for_service('/rival1/BallHandle')
            self.call_Handle = rospy.ServiceProxy('/rival1/BallHandle',
                                                  BallHandle)

        else:
            # Coach node: referees the match and publishes the outcome topics.
            rospy.init_node('coach', anonymous=True)
            self.state_pub = rospy.Publisher('/coach/state',
                                             String,
                                             queue_size=1)
            self.reward_pub = rospy.Publisher('/coach/reward',
                                              Float32,
                                              queue_size=1)
            self.done_pub = rospy.Publisher('coach/done', Bool, queue_size=1)
            self.HowEnd_pub = rospy.Publisher('coach/HowEnd',
                                              Int16,
                                              queue_size=1)
            rospy.Subscriber("/nubot1/omnivision/OmniVisionInfo",
                             OminiVisionInfo, self.callback)
            rospy.Subscriber("/rival1/steal", Bool,
                             self.steal_callback)  # steal
            rospy.Subscriber("/nubot1/A_info", Float32MultiArray,
                             self.A_info_callback)
            # rospy.Subscriber('gazebo/model_states', ModelStates, self.fly_callback)
            rospy.Subscriber('nubot1/ready2restart', Bool,
                             self.ready2restart_callback)
            rospy.wait_for_service('/gazebo/reset_simulation')
            self.call_restart = rospy.ServiceProxy('/gazebo/reset_simulation',
                                                   Empty)

    def ball_out(self):
        if self.BallPosX >= 875 or self.BallPosX <= -875 or self.BallPosY >= 590 or self.BallPosY <= -590:
            self.show('Out')
            self.is_out = True

    def ball_in(self):
        if self.BallPosX >= 870 and self.BallPosX <= 900 and self.BallPosY >= -100 and self.BallPosY <= 100:
            self.show('in')
            self.is_in = True

    def fly(self):
        """Set ``self.is_fly`` when either robot's z exceeds 0.34."""
        airborne = self.A_z > 0.34 or self.B_z > 0.34
        if airborne:
            self.is_fly = True

    def steal(self):
        """Set ``self.is_steal`` when the B handle holds the ball and A's does not.

        Blocks until both BallHandle services are available, then queries
        ``call_B_Handle``. ``call_Handle`` is only queried when the B handle
        reports holding.
        """
        for service_name in ('/nubot1/BallHandle', '/rival1/BallHandle'):
            rospy.wait_for_service(service_name)
        b_holds = self.call_B_Handle(1).BallIsHolding
        if b_holds and not self.call_Handle(1).BallIsHolding:
            self.is_steal = True

    def stealorfly(self):
        """Fold the fly and steal flags into ``self.is_stealorfly`` (set-only)."""
        triggered = self.is_fly or self.is_steal
        if triggered:
            self.is_stealorfly = True

    def show(self, state):
        """Print ``state`` only if it differs from the previously shown state.

        The last state is kept in the module-level ``_state`` global, so it is
        shared across all callers. It is overwritten on every call.
        """
        global _state
        is_new = state != _state
        if is_new:
            print(state)
        _state = state

    def cnt_rwd(self):
        """Build this step's reward-unit vector, print it, and return the reward.

        Slots of the 8-element payload are
        ``[kick_count, cal_dis2start(), cal_dis2goal(), B_dis, in, out, steal, reward]``.
        When ``game_is_done()`` is true, ``self.HowEnd`` is one-hot encoded
        into slots 4-6 (1 -> slot 4, -2 -> slot 5, -1 -> slot 6).

        A fresh ``sac_calculate`` is stored on ``self.sac_cal``; callers use
        it afterwards. Its ``reward`` is computed from the vector and written
        into slot 7.

        Returns:
            The value returned by ``sac_calculate.reward``.
        """
        data = Float32MultiArray()
        data.data = [
            self.kick_count,
            self.cal_dis2start(),
            self.cal_dis2goal(),
            self.B_dis,
            0, 0, 0, 0,
        ]
        # Outcome code -> values for slots 4, 5, 6.
        outcome_slots = {1: (1, 0, 0), -2: (0, 1, 0), -1: (0, 0, 1)}
        if self.game_is_done():
            flags = outcome_slots.get(self.HowEnd)
            if flags is not None:
                data.data[4:7] = list(flags)
        # self.A_info_pub.publish(data) # 2C
        self.sac_cal = sac_calculate()
        reward = self.sac_cal.reward(data.data)
        data.data[7] = reward
        # NOTE(review): this header lists g_dis before st_dis, but slot 1 holds
        # cal_dis2start() and slot 2 holds cal_dis2goal(). Confirm which order
        # sac_calculate.reward expects before relying on the labels.
        print('rwd init',
              ['kck_n  g_dis st_dis  opp_dis  in    out   steal   ttl'])
        print('rwd unit', np.around((data.data), decimals=1))
        print('rwd :', reward)
        return reward

    def kick(self):
        """Fire the shooter toward the ball, then chase until the ball leaves.

        Publishes one velocity command of magnitude ``MaxSpd_A`` along
        ``self.RadHead2Ball`` with zero rotation. It then calls the Shoot
        service with the module-level ``pwr`` and mode 1.

        Afterwards it keeps chasing at ``MaxSpd_A``. It stops silently when
        the game is done and ``real_resart`` is set. It stops after counting
        a kick when ``call_Handle`` no longer reports holding the ball.
        """
        global MaxSpd_A
        global pwr
        self.vec.Vx = MaxSpd_A * math.cos(self.RadHead2Ball)
        self.vec.Vy = MaxSpd_A * math.sin(self.RadHead2Ball)
        self.vec.w = 0  # no rotation while striking
        self.vel_pub.publish(self.vec)
        self.call_Shoot(pwr, 1)  # power from GAFuzzy
        while True:
            self.chase(MaxSpd_A)
            if self.game_is_done() and self.real_resart:
                return
            if not self.call_Handle(1).BallIsHolding:
                break
            print('in')
        # The ball has left the holding device: count the kick.
        self.kick_count += 1
        print("Kick: %d" % self.kick_count)
        print('-------')

    def chase(self, MaxSpd):
        """Publish a velocity toward the ball while turning to face it.

        Linear speed ``MaxSpd`` is split along ``self.RadHead2Ball``. The
        angular speed is that angle times the module-level ``RotConst``.
        """
        command = self.vec
        command.Vx = MaxSpd * math.cos(self.RadHead2Ball)
        command.Vy = MaxSpd * math.sin(self.RadHead2Ball)
        command.w = self.RadHead2Ball * RotConst
        self.vel_pub.publish(command)

    def chase_B(self, MaxSpd):
        """Like ``chase`` but with a quarter of the rotation gain.

        Linear speed ``MaxSpd`` is split along ``self.RadHead2Ball``. The
        angular speed is that angle times ``RotConst / 4``.
        """
        command = self.vec
        command.Vx = MaxSpd * math.cos(self.RadHead2Ball)
        command.Vy = MaxSpd * math.sin(self.RadHead2Ball)
        command.w = self.RadHead2Ball * RotConst / 4
        self.vel_pub.publish(command)
    def turn(self, angle):
        """Keep moving toward the ball while rotating toward heading ``angle``.

        Linear speed is ``MaxSpd_A`` along ``self.RadHead2Ball``. The angular
        speed is ``turnHead2Kick(self.RadHead, angle) * RotConst``. After
        publishing, "Turning" is passed to ``show``.
        """
        global MaxSpd_A
        speed = MaxSpd_A
        self.vec.Vx = speed * math.cos(self.RadHead2Ball)
        self.vec.Vy = speed * math.sin(self.RadHead2Ball)
        self.vec.w = turnHead2Kick(self.RadHead, angle) * RotConst
        self.vel_pub.publish(self.vec)
        self.show("Turning")

    def cal_dis2start(self):  # last kick
        """Return the Euclidean distance from nubot1 to (StartX, StartY)."""
        return math.hypot(self.NunbotAPosX - self.StartX,
                          self.NunbotAPosY - self.StartY)

    def cal_dis2goal(self):  # last kick
        """Return the Euclidean distance from nubot1 to (GoalX, GoalY)."""
        return math.hypot(self.NunbotAPosX - self.GoalX,
                          self.NunbotAPosY - self.GoalY)

    def avg(self, n, l):
        """Push ``n`` into a sliding window and report the window mean.

        Drops the oldest element of ``l``, appends ``n`` and stores the
        resulting array in ``self.avg_arr``. It then prints the window and
        its mean. The window length is preserved.

        Args:
            n: The newest sample.
            l: The current window (array-like, non-empty).

        Returns:
            The mean of the updated window.
        """
        window = np.append(np.delete(l, 0), n)
        self.avg_arr = window
        # Divide by the actual window length; the former hard-coded 64 was
        # only correct for 64-element windows.
        mean = sum(window) / len(window)
        print(self.avg_arr)
        print(mean)
        return mean

    def reset_ball(self):
        """Teleport the 'football' model to (-6.8, 0, 0.12) until BallPosX is -680.

        A fresh ModelState with identity orientation is published on
        ``reset_pub`` on every iteration. BallPosX is presumably refreshed by
        a subscriber callback.
        NOTE(review): exact equality on BallPosX may never hold if the value
        is not integral; confirm how it is produced.
        """
        while self.BallPosX != -680:
            ball_state = ModelState()
            ball_state.model_name = 'football'
            position = ball_state.pose.position
            position.x, position.y, position.z = -6.8, 0, 0.12
            orientation = ball_state.pose.orientation
            orientation.x, orientation.y, orientation.z, orientation.w = 0, 0, 0, 1
            self.reset_pub.publish(ball_state)

    def reset_A(self):
        """Teleport the 'nubot1' model to (-8.3, 0, 0) until NunbotAPosX is -830.

        A fresh ModelState with identity orientation is published on
        ``reset_pub`` on every iteration.
        NOTE(review): the loop relies on exact equality of NunbotAPosX;
        confirm it is integral.
        """
        while self.NunbotAPosX != -830:
            robot_state = ModelState()
            robot_state.model_name = 'nubot1'
            position = robot_state.pose.position
            position.x, position.y, position.z = -8.3, 0, 0
            orientation = robot_state.pose.orientation
            orientation.x, orientation.y, orientation.z, orientation.w = 0, 0, 0, 1
            self.reset_pub.publish(robot_state)

    def reset_B(self):
        """Teleport the 'rival1' model to the origin until NunbotBPosX is 0.

        A ModelState at (0, 0, 0) with identity orientation is published on
        ``reset_pub`` on every iteration.
        NOTE(review): the loop relies on exact equality of NunbotBPosX;
        confirm it is integral.
        """
        while self.NunbotBPosX != 0:
            rival_state = ModelState()
            rival_state.model_name = 'rival1'
            # A randomised start was sketched here previously
            # (e.g. y = random.uniform(-5, 5)). The value was never used, so
            # the per-iteration random draw has been removed.
            rival_state.pose.position.x = 0
            rival_state.pose.position.y = 0
            rival_state.pose.position.z = 0
            rival_state.pose.orientation.x = 0
            rival_state.pose.orientation.y = 0
            rival_state.pose.orientation.z = 0
            rival_state.pose.orientation.w = 1
            self.reset_pub.publish(rival_state)

    def restart(self):
        """Reset the field for a new game and clear the per-game end flags.

        Steps, in order:
        1. Reset the ball.
        2. Print the game-over / game-start banner.
        3. Reset nubot1.
        4. Bump ``game_count`` and zero ``kick_count``.
        5. Reset the rival.
        6. Clear ``ready2restart`` and every end flag.
        """
        self.reset_ball()
        finished_game = self.game_count - 1
        print('Game %d over' % finished_game)
        print('-----------Restart-----------')
        print('Game %d start' % self.game_count)
        self.reset_A()
        self.game_count += 1
        self.kick_count = 0
        self.reset_B()
        for flag_name in ('ready2restart', 'is_fly', 'is_steal',
                          'is_stealorfly', 'is_in', 'is_out'):
            setattr(self, flag_name, False)
    def end_rate(self, end):
        """Record how a game ended and advance the in-rate milestones.

        ``end`` is the outcome code (1 = in, -2 = out, -1 = steal/fly). It is
        written into the 128-slot ring ``self.list_rate`` at
        ``game_count % 128``. The in/out/steal rates over the ring are printed.

        Milestone 0 fires on the first non-zero in-rate. Milestone k (1..10)
        fires once the in-rate reaches k/10. Each time the current milestone
        is reached, ``game_count`` / ``step_count`` are stored in
        ``self.milestone`` / ``self.step_milestone`` and
        ``self.milestone_idx`` advances. Several milestones can fire in one
        call.
        """
        window = 128
        self.list_rate[self.game_count % window] = end
        out_count = self.list_rate.count(-2)
        in_count = self.list_rate.count(1)
        steal_count = self.list_rate.count(-1)
        # Force float division: under Python 2 integer division every rate
        # would floor to 0 and the fractional milestones could never fire.
        in_rate = in_count / float(window)
        print('in_rate', in_rate, 'out_rate', out_count / float(window),
              'steal_rate', steal_count / float(window))

        thresholds = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0)
        while 0 <= self.milestone_idx <= len(thresholds):
            k = self.milestone_idx
            if k == 0:
                reached = in_rate != 0
            else:
                reached = in_rate >= thresholds[k - 1]
            if not reached:
                break
            self.milestone[k] = self.game_count
            self.step_milestone[k] = self.step_count
            self.milestone_idx = k + 1
        print('milestone', self.milestone)
        print('milestone', self.step_milestone)

    def game_is_done(self):
        """Refresh every end condition and report whether the game is over.

        Runs the checks ``ball_in``, ``ball_out``, ``steal``, ``fly`` and
        ``stealorfly`` in that order; each may set its flag.

        If an end flag is set, the outcome is stored in ``self.HowEnd`` with
        priority in (1) > out (-2) > steal/fly (-1), and True is returned.
        Otherwise False is returned and ``HowEnd`` is left untouched.
        """
        for check in (self.ball_in, self.ball_out, self.steal, self.fly,
                      self.stealorfly):
            check()
        for flag_name, outcome in (('is_in', 1), ('is_out', -2),
                                   ('is_stealorfly', -1)):
            if getattr(self, flag_name):
                self.HowEnd = outcome
                return True
        return False

    def workA(self):
        """Main loop for nubot1: chase, pick an action, aim, kick, and learn.

        Each iteration re-opens the holding device and then does one of:

        * Game over (``game_is_done()`` and ``real_resart``):
          - compute the reward via ``cnt_rwd``;
          - build the terminal state from ``self.HowEnd``;
          - once i > 1, store the transition (when both states have 12
            entries) and update ``end_rate``;
          - learn once i > 512, then ``restart()``.
        * Ball not held: chase it, re-arm ``fisrt_time_hold`` and clear the
          per-game flags.
        * Ball held, first time since it was last free:
          - compute the step reward and store the transition;
          - choose a new action (turn angle via ``sac_cal.output``, shot
            power from ``self.a[1]``);
          - learn once i > 512.
        * Ball held afterwards: turn toward ``self.RadTurn``, then ``kick()``
          once within ``angle_thres``.

        Runs until ROS shuts down.

        NOTE(review): ``self.done`` is only ever set to False here, so every
        stored transition (including the terminal one) carries done=False.
        Confirm this is intended.
        """
        np.set_printoptions(suppress=True)
        i = 0  # counts decision points (touches and game ends)
        # Set while the ball is free; consumed by the first held iteration.
        fisrt_time_hold = False
        while not rospy.is_shutdown():
            # print(self.ball_in(), self.ball_out(), self.stealorfly())
            rospy.wait_for_service('/nubot1/BallHandle')
            self.call_Handle(1)  # open holding device
            if self.game_is_done() and self.real_resart:
                # Terminal step: score the finished game.
                # print('self.game_is_done()',self.game_is_done())
                self.r = self.cnt_rwd()
                # print('h',self.HowEnd)
                s_ = self.sac_cal.input(self.HowEnd)  # terminal state from the outcome code
                if i > 1:
                    # print ('12?')
                    # print (len(self.s))
                    # Only store well-formed 12-element states.
                    if len(self.s) == 12 and len(s_) == 12:
                        # print('000000000000000000',np.shape(self.s), np.shape(self.a))
                        self.agent.memory.store_transition(
                            self.s, self.a, self.r, s_, self.done)
                    # print('d',self.done)
                    self.ep_rwd = self.r
                    print('ep rwd value=', self.r)
                    self.end_rate(self.HowEnd)
                self.step_count = self.step_count + 1
                i += 1
                self.s = s_
                self.done = False
                if i > 512:
                    # Learning starts only after 512 decision points.
                    self.agent.learn(i, self.r, self.ep_rwd)
                # self.ready2restart_pub.publish(True)
                # self.ready2restart_pub.publish(False)
                self.real_resart = False
                self.HowEnd = 0
                # print('i want to go in self.restart()')
                self.restart()
                # self.end_rate(self.HowEnd)
                # print('---')
            # elif not self.game_is_done():

            else:
                # print('self.game_is_done()',self.game_is_done())
                rospy.wait_for_service('/nubot1/BallHandle')
                if not self.call_Handle(1).BallIsHolding:  # BallIsHolding = 0
                    self.chase(MaxSpd_A)
                    fisrt_time_hold = True
                    rospy.wait_for_service('/nubot1/BallHandle')
                    # do real reset before holding
                    self.ready2restart = False
                    self.is_fly = False
                    self.is_steal = False
                    self.is_stealorfly = False
                    self.is_in = False
                    self.is_out = False
                    self.real_resart = True  # a game end may now trigger a restart
                # Second service query: if the ball was released in between,
                # neither branch runs this iteration.
                elif self.call_Handle(1).BallIsHolding:  # BallIsHolding = 1
                    global RadHead
                    self.chase(MaxSpd_A)
                    if fisrt_time_hold == True:
                        # First held iteration: reward the touch and pick an action.
                        self.real_resart = True  # a game end may now trigger a restart
                        self.chase(MaxSpd_A)
                        self.show('Touch')
                        self.r = self.cnt_rwd()
                        s_ = self.sac_cal.input(0)  # non-terminal state for the agent
                        if i >= 1:
                            if len(self.s) == 12 and len(s_) == 12:
                                self.agent.memory.store_transition(
                                    self.s, self.a, self.r, s_, self.done)
                            print('step rwd value= ', self.r)
                        self.step_count = self.step_count + 1
                        self.done = False
                        i += 1
                        self.s = s_
                        self.a = self.agent.choose_action(self.s,
                                                          )  # action from the agent
                        rel_turn_ang = self.sac_cal.output(
                            self.a)  # relative turn angle in degrees (converted below)

                        global pwr, MAX_PWR
                        # Assumes a[1] in [-1, 1] -> pwr in [1.4, MAX_PWR + 1.4]; confirm.
                        pwr = (self.a[1] + 1) * MAX_PWR / 2 + 1.4  # shot power used by kick()

                        rel_turn_rad = math.radians(rel_turn_ang)
                        self.RadTurn = rel_turn_rad + self.RadHead  # absolute target heading
                        fisrt_time_hold = False
                        if i > 512:
                            self.agent.learn(i, self.r, self.ep_rwd)
                    elif fisrt_time_hold == False:
                        # Action already chosen: aim, then shoot.
                        self.chase(MaxSpd_A)
                        error = math.fabs(
                            turnHead2Kick(self.RadHead, self.RadTurn))
                        if error > angle_thres:  # not yet facing the target heading
                            self.turn(self.RadTurn)
                        else:  # facing the target heading: shoot
                            self.kick()

    def workB(self):
        """Rival loop: chase the ball and publish whether it is being held.

        Until ROS shuts down:
        - wait for the rival BallHandle service and query it;
        - if holding, publish True on ``steal_pub``;
        - otherwise publish False and chase at ``MaxSpd_B`` with ``chase_B``.
        """
        while not rospy.is_shutdown():
            rospy.wait_for_service('/rival1/BallHandle')
            if self.call_Handle(1).BallIsHolding:
                self.steal_pub.publish(True)  # 2C
                continue
            self.steal_pub.publish(False)
            self.chase_B(MaxSpd_B)

    def coach(self):
        """Placeholder hook for coach logic; currently does nothing."""

    def workC(self):
        """Coach loop: detect game end, publish the outcome, restart on request.

        Until ROS shuts down:
        - refresh the steal-or-fly, ball-out and ball-in flags;
        - when any is set, publish the outcome code on ``HowEnd_pub``
          (1 = in, -2 = out, -1 = steal/fly; same priority as
          ``game_is_done``) and True on ``done_pub``;
        - if ``self.ready2restart`` is set, call ``restart()``.

        ``stealorfly``, ``ball_out`` and ``ball_in`` return None and report
        only through ``self.is_stealorfly`` / ``self.is_out`` / ``self.is_in``.
        The previous version tested their return values, so the end-of-game
        branch could never run.
        """
        print('Game 1 start')
        np.set_printoptions(suppress=True)
        while not rospy.is_shutdown():
            # These checks update flags on self; they do not return them.
            self.stealorfly()
            self.ball_out()
            self.ball_in()

            if self.is_in:
                HowEnd = 1
            elif self.is_out:
                HowEnd = -2
            elif self.is_stealorfly:
                HowEnd = -1
            else:
                continue
            self.HowEnd_pub.publish(HowEnd)
            self.done_pub.publish(True)
            if self.ready2restart:
                self.restart()
示例#10
0
# Train one DDPG agent per entry of the module-level `seeds` list on the
# module-level `env`. Each agent gets its own TF namespace, and numpy, env
# and TF are all seeded with the same value.
for i in range(len(seeds)):
    agent = DDPG(a_dim=6, s_dim=17, a_bound=1, lr_a=0.0001, lr_c=0.001, seed=seeds[i], namespace='ddpg_' + str(i))
    exploration_rate = 0.2  # probability of taking a uniformly sampled action
    np.random.seed(seeds[i])
    env.seed(seeds[i])
    tf.set_random_seed(seeds[i])

    total_reward = []  # NOTE(review): never appended to; rewards are only printed
    for episode in range(1000):
        state = env.reset()
        var = 0.1  # std of the Gaussian noise on policy actions (reset each episode)
        cum_reward = 0
        for step in range(1000):
            # action = np.clip(np.random.normal(np.reshape(agent.choose_action(state), [6, ]), var), -1, 1)
            if np.random.uniform() > exploration_rate:
                # Exploit: noisy policy action, clipped to [-1, 1].
                action = np.clip(np.random.normal(np.reshape(agent.choose_action(state), [6, ]), var), -1, 1)
            else:
                # Explore: uniform sample from the action space.
                action = env.action_space.sample()
            next_state, reward, done, _ = env.step(action)
            # print(action)
            cum_reward += reward
            agent.store_transition(state, action, reward, next_state, done)
            state = next_state
            agent.learn()  # called after every environment step
            if done:
                print('Episode', episode, ' Complete at reward ', cum_reward, '!!!')
                # print('Final velocity x is ',state[9])
                # print('Final velocity z is ',state[10])
                break
            if step == 1000 - 1:
                # Episode hit the 1000-step cap without terminating.
                print('Episode', episode, ' finished at reward ', cum_reward)
# Setup for the next (separate) example: one DDPG agent with a large
# initial exploration noise.
ddpg = DDPG(s_dim, a_dim, a_bound)

var = 3  # control exploration
t1 = time.time()

for i in range(MAX_EPISODES):
    s = env.reset()
    #print('s1:',s)
    ep_reward = 0

    for j in range(MAX_EP_STEPS):
        if RENDER:
            env.render()

        # Add exploration noise
        a = ddpg.choose_action(s)
        a = np.clip(np.random.normal(a, var), -2,
                    2)  # add randomness to action selection for exploration
        s_, r, done, info = env.step(a)

        ddpg.store_transition(s, a, r / 10, s_, done)

        if ddpg.memory_count == ddpg.memory_size:
            var *= 0.9995  # decay the action randomness
            ddpg.learn()

        s = s_
        ep_reward += r

        if j == MAX_EP_STEPS - 1:
            print(
def run(env_name: str, algorithm: str, max_ep, max_ep_step):
    """Evaluate a saved DDPG agent, or train a PPO1 policy, on a Gym env.

    Args:
        env_name: Gym environment id. The last three characters (presumably
            the ``-vN`` suffix) are stripped and lower-cased to build the
            checkpoint path ``saved model/<env>/<algorithm>/<env>.ckpt``.
        algorithm: ``'ddpg'`` restores the checkpoint and renders greedy
            rollouts, then saves the checkpoint back. ``'ppo'`` trains a
            baselines PPO1 MLP policy with its default timestep budget. Any
            other value does nothing.
        max_ep: Number of DDPG evaluation episodes.
        max_ep_step: Step cap per DDPG episode.

    The environment is always closed on exit, and the DDPG session is
    always closed, even if an exception is raised.
    """
    env = gym.make(env_name)
    try:
        # env = env.unwrapped
        env.seed(1)

        s_dim = env.observation_space.shape[0]
        a_dim = env.action_space.shape[0]
        a_bound = env.action_space.high[0]
        path = (r"saved model/" + env_name[:-3].lower() + "/" + algorithm +
                "/" + env_name[:-3].lower() + ".ckpt")
        if algorithm == 'ddpg':
            from ddpg import DDPG
            agent = DDPG(a_dim,
                         s_dim,
                         a_bound,
                         lr_a=1e-4,
                         lr_c=2e-4,
                         var_decay=.9999,
                         var=0.5,
                         batch_size=128,
                         graph=False,
                         memory_capacity=4000)
            saver = tf.train.Saver()
            try:
                saver.restore(agent.sess, path)

                for i in range(max_ep):
                    observation = env.reset()
                    ep_reward = 0

                    for step in range(max_ep_step):
                        env.render()
                        # Greedy action: exploration disabled for evaluation.
                        act = agent.choose_action(observation, explore=False)

                        observation_, reward, done, _ = env.step(act)
                        ep_reward += reward

                        # Transitions are not stored: this is an evaluation run.
                        if done:
                            print(step, "steps")
                            break

                        observation = observation_
                    print('Episode:', i, ' Reward: %f' % ep_reward,
                          'Explore: %f' % agent.var)
                    # NOTE(review): `interrupt` is not defined in this function;
                    # presumably a module-level flag set elsewhere - confirm.
                    if interrupt:
                        break
                saver.save(agent.sess, path, write_meta_graph=False)
            finally:
                agent.sess.close()

        elif algorithm == 'ppo':
            from baselines.common import tf_util as U

            def train(num_timesteps=10000):
                from baselines.ppo1 import mlp_policy, pposgd_simple
                U.make_session(num_cpu=1).__enter__()

                def policy_fn(name, ob_space, ac_space):
                    return mlp_policy.MlpPolicy(name=name,
                                                ob_space=ob_space,
                                                ac_space=ac_space,
                                                hid_size=64,
                                                num_hid_layers=2)

                pposgd_simple.learn(
                    env,
                    policy_fn,
                    max_timesteps=num_timesteps,
                    timesteps_per_actorbatch=2048,
                    clip_param=0.2,
                    entcoeff=0.0,
                    optim_epochs=10,
                    optim_stepsize=3e-4,
                    optim_batchsize=64,
                    gamma=0.99,
                    lam=0.95,
                    schedule='linear',
                )

            # `train` takes a timestep budget, not the environment (the env is
            # captured from the enclosing scope). Passing `env` here used to
            # feed the Env object in as `max_timesteps`.
            train()
    finally:
        env.close()