Example 1
def main():
    map_dir = "/home/i2rlab/shahil_files/shahil_RL_ws_new/src/turtlebot/turtlebot_gazebo/worlds/"
    #os.rename(map_dir+"house_train.world",map_dir+"maze.world")
    rospy.init_node('turtlebot2_human', anonymous=True, log_level=rospy.WARN)
    env = StartOpenAI_ROS_Environment(
            'MyTurtleBot2HumanModel-v1')
    #os.rename(map_dir+"maze.world",map_dir+"house_train.world")
    s_dim = env.observation_space.shape[0]+4
    a_dim = env.action_space.shape[0]
    a_bound = env.action_space.high
    #print('a_bound test', a_bound)
    sign_talker = rospy.Publisher('/sign', Int64, queue_size=1)
    ####################### load agent #######################################
    #TD3 = td3(a_dim, s_dim, a_bound, GAMMA, TAU, MEMORY_CAPACITY, BATCH_SIZE, LR_A, LR_C)
    #TD3.loader()

    d = [0.5,1,0.5,1]+list(6*np.ones([180]))
    #A = np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]])
    #TGREEN =  '\033[32m' # Green Text
    #ENDC = '\033[m' # reset to the defaults

    #################### Start ####################
    i = 0
    while i < 2:
        s = env.reset()
        #print("env printing: ",env)
        t1 = time.time()
        min_distance_ob = s[-2]
        theta_bt_uh_ob = s[-1]
        #a_old = A[0]
        a_old = np.array([1, 0, 0, 0])
        s = np.hstack((a_old, np.array(s[3:-2]) / d))
        for j in range(MAX_EP_STEPS):
            ################ Mode #########################################
            #if i%2==0:
                #a = TD3.choose_action(s,1)
                #a = [1,0,0,0]
                #if a[1]<=0.9 and min_distance_ob<=0.7 and abs(theta_bt_uh_ob)<1.3:
                    #print (TGREEN + "Going to Crash!" , ENDC)
                    #a = A[1]
            #else:
            #    a = A[0]

            #x_test = env.state_msg.pose.position.x
            #print('x_test',env.collect_data)

            a = np.array([1,0,0,0])
            s_, r, done, info = env.step(a)
            sign_talker.publish(int(r))
            if r < -200:
                done = 0  # treat a crash (large negative reward) as non-terminal
            min_distance_ob = s_[-2]
            theta_bt_uh_ob = s_[-1]
            a_old = a
            s = np.hstack((a_old, np.array(s_[3:-2]) / d))
            if done or j == MAX_EP_STEPS-1:
                t = time.time() - t1
                if j > 3:  # only count episodes that lasted more than a few steps
                    i += 1
                break
Example 2
    def test(self):
        # Initialize OpenAI_ROS ENV
        LoadYamlFileParamsTest(rospackage_name="learning_ros", rel_path_from_package_to_file="config", yaml_file_name="aliengo_stand.yaml")
        env = StartOpenAI_ROS_Environment('AliengoStand-v0')
        saver = tf.train.Saver()

        time.sleep(3)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, 'model_test/model_ddqn.ckpt')
            scores = []
            success_num = 0

            for e in range(self.EPISODES):
                state = env.reset()
                state = np.reshape(state, [1, self.state_size])
                done = False
                i = 0
                rewards = []
                while not rospy.is_shutdown():  # loop until ROS shuts down
                    # greedy action: pick the action with the highest predicted Q-value
                    action = np.argmax(self.model.predict(state))
                    next_state, reward, done, _ = env.step(action)
                    time.sleep(0.01)

                    next_state = np.reshape(next_state, [1, self.state_size])
                    rewards.append(reward)
                    self.remember(state, action, reward, next_state, done)
                    
                    i += 1
                    if done:
                        # episode finished: sync the target model with the online model
                        self.update_target_model()
                        

                        state = env.reset()
                        break 
                    else:
                        state = next_state   
                    self.replay()
                scores.append(sum(rewards))
               
        plt.plot(scores)
        plt.show()        
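The test() excerpt above relies on self.remember(), self.update_target_model(), self.replay(), and self.model, none of which are shown here. As a hedged sketch only (the class name, layer sizes, and buffer size are assumptions, not the original code), the pieces it depends on could look roughly like this in Keras:

from collections import deque

from keras.models import Sequential
from keras.layers import Dense


class DQNAgentSketch:
    """Minimal stand-in for the agent used by test() above; names and sizes are assumptions."""

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)          # replay buffer
        self.model = self._build_model()          # online network
        self.target_model = self._build_model()   # target network

    def _build_model(self):
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer='adam')
        return model

    def remember(self, state, action, reward, next_state, done):
        # store one transition for later experience replay
        self.memory.append((state, action, reward, next_state, done))

    def update_target_model(self):
        # copy the online-network weights into the target network
        self.target_model.set_weights(self.model.get_weights())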
Example 3
            qlearn.epsilon *= epsilon_discount

        # Initialize the environment and get first state of the robot
        observation = env.reset()
        state = ''.join(map(str, observation))

        # Show on screen the actual situation of the robot
        # env.render()
        # for each episode, we test the robot for nsteps
        for i in range(nsteps):
            rospy.logwarn("############### Start Step=>" + str(i))
            # Pick an action based on the current state
            action = qlearn.chooseAction(state)
            rospy.logwarn("Next action is:%d", action)
            # Execute the action in the environment and get feedback
            observation, reward, done, info = env.step(action)

            rospy.logwarn(str(observation) + " " + str(reward))
            cumulated_reward += reward
            if highest_reward < cumulated_reward:
                highest_reward = cumulated_reward

            nextState = ''.join(map(str, observation))

            # Make the algorithm learn based on the results
            rospy.logwarn("# state we were=>" + str(state))
            rospy.logwarn("# action that we took=>" + str(action))
            rospy.logwarn("# reward that action gave=>" + str(reward))
            rospy.logwarn("# episode cumulated_reward=>" +
                          str(cumulated_reward))
            rospy.logwarn(
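The qlearn object and its chooseAction() method are created outside this excerpt. Purely as an assumption-based illustration (not the original class), an epsilon-greedy tabular Q-learner consistent with the calls above might look like this:

import random


class QLearnSketch:
    """Minimal tabular Q-learning agent; the name and hyperparameters are assumptions."""

    def __init__(self, actions, alpha=0.2, gamma=0.8, epsilon=0.9):
        self.q = {}                 # Q-table: (state, action) -> value
        self.actions = actions      # list of discrete action indices
        self.alpha = alpha          # learning rate
        self.gamma = gamma          # discount factor
        self.epsilon = epsilon      # exploration rate (decayed by the caller)

    def chooseAction(self, state):
        # epsilon-greedy selection over the discrete action set
        if random.random() < self.epsilon:
            return random.choice(self.actions)
        q_values = [self.q.get((state, a), 0.0) for a in self.actions]
        return self.actions[q_values.index(max(q_values))]

    def learn(self, state, action, reward, next_state):
        # one-step update: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
        best_next = max(self.q.get((next_state, a), 0.0) for a in self.actions)
        old = self.q.get((state, action), 0.0)
        self.q[(state, action)] = old + self.alpha * (reward + self.gamma * best_next - old)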
Example 4
                       image_size=IM_SIZE)

    image_transformer = ImageTransformer(IM_SIZE)

    with tf.Session() as sess:
        model.set_session(sess)
        target_model.set_session(sess)
        #model.load()
        #target_model.load()
        sess.run(tf.global_variables_initializer())
        print("Initializing experience replay buffer...")
        obs = env.reset()

        for i in range(MIN_EXPERIENCE):
            action = np.random.choice(K)
            obs, reward, done, _ = env.step(action)
            obs_small = image_transformer.transform(obs, sess)
            experience_replay_buffer.add_experience(action, obs_small, reward,
                                                    done)
            if done:
                obs = env.reset()

        print("Done! Starts Training!!")
        t0 = datetime.now()
        for i in range(num_episodes):
            msg_data = Int16()
            msg_data.data = i
            episode_counter_pub.publish(msg_data)
            total_t, episode_reward, duration, num_steps_in_episode, time_per_step, epsilon = play_ones(
                env, sess, total_t, experience_replay_buffer, model,
                target_model, image_transformer, gamma, batch_sz, epsilon,
Example 5
         env,
         verbose=1,
         tensorboard_log="../results/tensorboard_logs/PPO2/"),
    TRPO(MlpPolicy,
         env,
         verbose=1,
         tensorboard_log="../results/tensorboard_logs/TRPO/"),
]

algo_list = ['A2C', 'ACKTR', 'PPO2', 'TRPO']

# TEST
# model_list = [model_list[5], model_list[5]]
# algo_list = [algo_list[5], algo_list[5]]

rate = rospy.Rate(30)

for model, algo in zip(model_list, algo_list):
    print(algo)
    model = model.load("../results/trained_models/" + algo)
    obs = env.reset()

    for i in range(500):
        if i % 100 == 0:
            print(i)
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        rate.sleep()

env.close()
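The opening of this snippet, including the start of model_list, is cut off; only the tail of the PPO2 entry and the TRPO entry survive. Going by algo_list, a plausible (hypothetical, not recovered) reconstruction of the full list with stable-baselines would be:

from stable_baselines import A2C, ACKTR, PPO2, TRPO
from stable_baselines.common.policies import MlpPolicy

# 'env' is the environment created in the missing part of the script
model_list = [
    A2C(MlpPolicy, env, verbose=1,
        tensorboard_log="../results/tensorboard_logs/A2C/"),
    ACKTR(MlpPolicy, env, verbose=1,
          tensorboard_log="../results/tensorboard_logs/ACKTR/"),
    PPO2(MlpPolicy, env, verbose=1,
         tensorboard_log="../results/tensorboard_logs/PPO2/"),
    TRPO(MlpPolicy, env, verbose=1,
         tensorboard_log="../results/tensorboard_logs/TRPO/"),
]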
Example 6
        state = [round(num, 1) for num in state]
        list_state = state
        #print("\n type(state): ")
        #print(type(state))
        rospy.logwarn("# state we are => " + str(state))
        state = torch.from_numpy(np.array(state)).float().unsqueeze(0).to(device)
        while not done:
            rospy.logwarn("i_episode: " + str(i_episode))
            rospy.logwarn("step_count: " + str(step_count))
            # Select an action
            eps_greedy_threshold = compute_eps_threshold(step_count, eps_start, eps_end, eps_decay)
            action, policy_act, policy_used = select_action(policy_net, state, device, env, eps_greedy_threshold, n_actions)
            rospy.logwarn("Next action is:%d", action)

            # Perform action in env
            next_state, reward, done, _ = env.step(action.item())
            next_state = [round(num, 1) for num in next_state]
            list_next_state = next_state
            #rospy.logwarn(str(next_state) + " " + str(reward))

            # Bookkeeping
            next_state = torch.from_numpy(np.array(next_state)).float().unsqueeze(0).to(device)
            #reward = reward_shaper(reward, done)
            reward = torch.tensor([reward], device=device)
            step_count += 1
            

            # Store the transition in memory
            memory.push(state, action, next_state, reward)
            memory.push_trace(i_episode, step_count, list_state, action.item(), list_next_state, reward.item(), policy_act.item(), eps_greedy_threshold, policy_used)
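compute_eps_threshold() and select_action() are defined elsewhere in this script. As a hedged sketch under the assumption of the usual exponential epsilon schedule and epsilon-greedy selection (the real definitions may differ), they could be:

import math
import random

import torch


def compute_eps_threshold(step, eps_start, eps_end, eps_decay):
    # exponentially anneal epsilon from eps_start towards eps_end as steps accumulate
    return eps_end + (eps_start - eps_end) * math.exp(-1.0 * step / eps_decay)


def select_action(policy_net, state, device, env, eps_threshold, n_actions):
    # epsilon-greedy over the policy network's action values (env is unused in this sketch)
    with torch.no_grad():
        policy_act = policy_net(state).max(1)[1].view(1, 1)     # greedy action index
    if random.random() > eps_threshold:
        return policy_act, policy_act, True                     # the policy was used
    random_act = torch.tensor([[random.randrange(n_actions)]],
                              device=device, dtype=torch.long)
    return random_act, policy_act, False                        # a random action was used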
Example 7
    def run(self):
        # Initialize OpenAI_ROS ENV
        LoadYamlFileParamsTest(rospackage_name="learning_ros",
                               rel_path_from_package_to_file="config",
                               yaml_file_name="aliengo_stand.yaml")
        env = StartOpenAI_ROS_Environment('AliengoStand-v0')
        saver = tf.train.Saver()

        time.sleep(3)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            scores = []
            success_num = 0

            for e in range(self.EPISODES):
                state = env.reset()
                state = np.reshape(state, [1, self.state_size])
                done = False
                rewards = []
                while not rospy.is_shutdown():  # loop until ROS shuts down
                    action = self.act(state)
                    next_state, reward, done, _ = env.step(action)
                    time.sleep(0.01)

                    next_state = np.reshape(next_state, [1, self.state_size])
                    rewards.append(reward)
                    self.remember(state, action, reward, next_state, done)

                    if done:
                        # episode finished: sync the target model with the online model
                        self.update_target_model()

                        state = env.reset()
                        break
                    else:
                        state = next_state
                    self.replay()
                    # after 10 consecutive high-reward episodes, training ends (see check below)
                scores.append(sum(rewards))
                if sum(rewards) >= 400:
                    success_num += 1
                    print("Succes number: " + str(success_num))
                    if success_num >= 5:  #checkpoint
                        if (ddqn):
                            saver.save(sess, 'model_train/model_ddqn.ckpt')
                        else:
                            saver.save(sess, 'model_train/model_dqn.ckpt')

                        print('Clear!! Model saved.')
                    if success_num >= 10:
                        if (ddqn):
                            saver.save(sess, 'model_train/model_ddqn.ckpt')
                        else:
                            saver.save(sess, 'model_train/model_dqn.ckpt')

                        print('Clear!! Model saved. AND Finished! ')
                        break

                else:
                    success_num = 0
        plt.plot(scores)
        plt.show()
Example 8
def main():
    name = input('Please input your exp number:')
    folder = "Human_exp_new_" + str(name)
    os.system("mkdir " + folder)

    ####################### init environment ########################################
    map_dir = "/home/i2rlab/shahil_files/shahil_RL_ws_new/src/turtlebot/turtlebot_gazebo/worlds/"
    #os.rename(map_dir+"house_3.world",map_dir+"maze.world")
    rospy.init_node('turtlebot2_human', anonymous=True, log_level=rospy.WARN)
    env = StartOpenAI_ROS_Environment('MyTurtleBot2HumanModel-v1')
    #os.rename(map_dir+"maze.world",map_dir+"house_3.world")
    s_dim = env.observation_space.shape[0] + 4
    a_dim = env.action_space.shape[0]
    a_bound = env.action_space.high
    sign_talker = rospy.Publisher('/sign', Int64, queue_size=1)
    ####################### load agent #######################################
    TD3 = td3(a_dim, s_dim, a_bound, GAMMA, TAU, MEMORY_CAPACITY, BATCH_SIZE,
              LR_A, LR_C)
    TD3.loader()
    # ddpg = DDPG(a_dim, s_dim, a_bound, GAMMA, TAU, MEMORY_CAPACITY, BATCH_SIZE, LR_A, LR_C)
    # with ddpg.sess.as_default():
    #     with ddpg.g.as_default():
    #         ddpg.loader()
    d = [0.5, 1, 0.5, 1] + list(6 * np.ones([180]))
    A = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
    TGREEN = '\033[32m'  # Green Text
    ENDC = '\033[m'  # reset to the defaults

    #################### Initialize Data #################################
    n_h = 0  ############## number of eps of pure human control #################
    n_s_TD3 = 0  ############## number of eps of shared control #################
    # n_s_ddpg = 0

    n_crash_h = 0  ############## number of crash in pure human control #################
    n_crash_s_TD3 = 0  ############## number of crash in shared control #################
    # n_crash_s_ddpg = 0

    n_success_h = 0  ############## number of success of pure human control #################
    n_success_s_TD3 = 0  ############## number of success of shared control #################
    # n_success_s_ddpg = 0

    t_h = 0  ############## total time of pure human control #################
    t_s_TD3 = 0  ############## total time of shared control #################
    # t_s_ddpg = 0
    #################### SART #####################
    SA_h = 0
    SA_s_TD3 = 0
    # SA_s_ddpg = 0
    #################### workload #################
    WL_h = 0
    WL_s_TD3 = 0
    # WL_s_ddpg = 0
    #################### Start ####################
    i = 0
    total_ep = 2
    sequence_ep = [0, 1]
    #np.random.shuffle(sequence_ep)
    pureh_ep = sequence_ep[0:int(total_ep / 2)]
    # ddpg_ep = sequence_ep[int(total_ep/3):int(2*total_ep/3)]
    td3_ep = sequence_ep[int(total_ep / 2):]
    n_d = 0
    while i < total_ep:
        s = env.reset()
        print(s)
        t1 = time.time()
        min_distance_ob = s[-2]
        theta_bt_uh_ob = s[-1]
        a_old = A[0]
        s = np.hstack((a_old, np.array(s[3:-2]) / d))
        done = 0
        while not done:
            ################ Mode #########################################
            if i in td3_ep:
                a = TD3.choose_action(s, 1)
                if a[1] <= 0.9 and min_distance_ob <= 0.7 and abs(
                        theta_bt_uh_ob) < 1.3:
                    print(TGREEN + "Going to Crash!", ENDC)
                    a = A[1]
            elif i in pureh_ep:
                a = A[0]

            s_, r, done, info = env.step(a)
            sign_talker.publish(int(r))
            if r < -400 and done:
                done = 0
                if i in td3_ep:
                    n_crash_s_TD3 += 1
                elif i in pureh_ep:
                    n_crash_h += 1
            print("n_crash_s_TD3:", n_crash_s_TD3)
            print("n_crash_h:", n_crash_h)
            print("i:", i)
            print("n_d:", n_d)
            min_distance_ob = s_[-2]
            theta_bt_uh_ob = s_[-1]
            a_old = a
            s = np.hstack((a_old, np.array(s_[3:-2]) / d))
            if done:
                t = time.time() - t1  # episode duration; t1 is taken right after env.reset()
                n_d = n_d + 1
                if n_d % 5 == 0:
                    instability = int(
                        input(
                            'How changeable is the situation (1: stable and straightforward -- 7: changing suddenly):'
                        ))
                    variability = int(
                        input(
                            'How many variables are changing within the situation: (1: very few -- 7: large number):'
                        ))
                    complexity = int(
                        input(
                            'How complicated is the situation (1: simple -- 7: complex):'
                        ))
                    arousal = int(
                        input(
                            'How aroused are you in the situation (1: low degree of alertness -- 7: alert and ready for activity):'
                        ))
                    spare = int(
                        input(
                            'How much mental capacity do you have to spare in the situation (1: Nothing to spare -- 7: sufficient):'
                        ))
                    concentration = int(
                        input(
                            'How much are you concentrating on the situation (1: focus on only one thing -- 7: many aspects):'
                        ))
                    attention = int(
                        input(
                            'How divided is your attention in the situation (1: focus on only one thing -- 7: many aspects):'
                        ))
                    quantity = int(
                        input(
                            'How much information have you gained about the situation (1: little -- 7: much):'
                        ))
                    quality = int(
                        input(
                            'How accessible and usable was the information you gained (1: poor -- 7: good):'
                        ))
                    familiarity = int(
                        input(
                            'How familiar are you with the situation (1: new situation -- 7: a great deal of relevant experience):'
                        ))
                    # SART score: understanding - (demand - supply)
                    SA = quantity + quality + familiarity - (
                        (instability + variability + complexity) -
                        (arousal + spare + concentration + attention))
                    if i in td3_ep:
                        SA_s_TD3 += SA
                        WL_s_TD3 += float(
                            input('Please input your workload (from TLX):'))
                        t_s_TD3 += t
                        n_s_TD3 += 1
                        if r > 500:
                            n_success_s_TD3 += 1
                    elif i in pureh_ep:
                        SA_h += SA
                        WL_h += float(
                            input('Please input your workload (from TLX):'))
                        t_h += t
                        n_h += 1
                        if r > 500:
                            n_success_h += 1
                    i = i + 1
                break

    np.savetxt(
        'data.dat',
        np.array([[
            n_s_TD3, t_s_TD3, n_crash_s_TD3, n_success_s_TD3, SA_s_TD3,
            WL_s_TD3
        ], [n_h, t_h, n_crash_h, n_success_h, SA_h, WL_h]]))
    #,[n_s_ddpg,t_s_ddpg,n_crash_s_ddpg,n_success_s_ddpg, SA_s_ddpg, WL_s_ddpg]]))
    ########### shared_TD3: number of eps, total time, number of crashes, number of successes, situation awareness, workload ###########
    ########### human: number of eps, total time, number of crashes, number of successes, situation awareness, workload ###########
    ########### shared_ddpg (commented out above): number of eps, total time, number of crashes, number of successes, situation awareness, workload ###########
    os.system('cp -r data.dat ' + folder + '/')
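For reference, reading data.dat back follows directly from the layout documented above (a small usage sketch; the file name and row order are taken from the code, the printed metric is illustrative):

import numpy as np

# row 0: shared TD3 control, row 1: pure human control
# columns: episodes, total time, crashes, successes, situation awareness (SART), workload (TLX)
data = np.loadtxt('data.dat')
shared_td3, pure_human = data[0], data[1]
print('shared-control success rate:', shared_td3[3] / max(shared_td3[0], 1))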
Example 9
class DQNRobotSolver():
    def __init__(self,
                 environment_name,
                 n_observations,
                 n_actions,
                 n_episodes=1000,
                 n_win_ticks=195,
                 min_episodes=100,
                 max_env_steps=None,
                 gamma=1.0,
                 epsilon=1.0,
                 epsilon_min=0.01,
                 epsilon_log_decay=0.995,
                 alpha=0.01,
                 alpha_decay=0.01,
                 batch_size=64,
                 monitor=False,
                 quiet=False):
        self.memory = deque(maxlen=100000)

        # self.env = gym.make(environment_name)
        self.env = StartOpenAI_ROS_Environment(environment_name)

        if monitor:
            self.env = gym.wrappers.Monitor(self.env,
                                            '../data/cartpole-1',
                                            force=True)

        self.input_dim = n_observations
        self.n_actions = n_actions
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_log_decay
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.n_episodes = n_episodes
        self.n_win_ticks = n_win_ticks
        self.min_episodes = min_episodes
        self.batch_size = batch_size
        self.quiet = quiet
        if max_env_steps is not None:
            self.env._max_episode_steps = max_env_steps

        # Init model
        self.model = Sequential()

        self.model.add(Dense(24, input_dim=self.input_dim, activation='tanh'))
        self.model.add(Dense(48, activation='tanh'))
        self.model.add(Dense(self.n_actions, activation='linear'))
        self.model.compile(loss='mse',
                           optimizer=Adam(lr=self.alpha,
                                          decay=self.alpha_decay))

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def choose_action(self, state, epsilon):
        return self.env.action_space.sample() if (
            np.random.random() <= epsilon) else np.argmax(
                self.model.predict(state))

    def get_epsilon(self, t):
        return max(
            self.epsilon_min,
            min(self.epsilon, 1.0 - math.log10((t + 1) * self.epsilon_decay)))

    def preprocess_state(self, state):
        return np.reshape(state, [1, self.input_dim])

    def replay(self, batch_size):
        x_batch, y_batch = [], []
        minibatch = random.sample(self.memory, min(len(self.memory),
                                                   batch_size))
        for state, action, reward, next_state, done in minibatch:
            y_target = self.model.predict(state)
            y_target[0][
                action] = reward if done else reward + self.gamma * np.max(
                    self.model.predict(next_state)[0])
            x_batch.append(state[0])
            y_batch.append(y_target[0])

        self.model.fit(np.array(x_batch),
                       np.array(y_batch),
                       batch_size=len(x_batch),
                       verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def run(self):

        rate = rospy.Rate(30)

        scores = deque(maxlen=100)

        for e in range(self.n_episodes):

            init_state = self.env.reset()

            state = self.preprocess_state(init_state)
            done = False
            i = 0
            while not done:
                # openai_ros doesn't support render() at the moment
                #self.env.render()
                action = self.choose_action(state, self.get_epsilon(e))
                next_state, reward, done, _ = self.env.step(action)
                next_state = self.preprocess_state(next_state)
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1

            scores.append(i)
            mean_score = np.mean(scores)
            if mean_score >= self.n_win_ticks and e >= self.min_episodes:
                if not self.quiet:
                    print('Ran {} episodes. Solved after {} trials'.format(
                        e, e - self.min_episodes))
                return e - self.min_episodes
            if e % 1 == 0 and not self.quiet:
                print(
                    '[Episode {}] - Mean survival time over last {} episodes was {} ticks.'
                    .format(e, self.min_episodes, mean_score))

            self.replay(self.batch_size)

        if not self.quiet: print('Did not solve after {} episodes'.format(e))
        return e
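A typical entry point for this solver (node name, environment name, and dimensions below are placeholders, not values taken from the original script) would look roughly like:

if __name__ == '__main__':
    import rospy

    rospy.init_node('dqn_robot_solver', anonymous=True, log_level=rospy.WARN)
    # Hypothetical environment name and sizes; in practice these usually come from ROS params.
    agent = DQNRobotSolver(environment_name='CartPoleStayUp-v0',
                           n_observations=4,
                           n_actions=2,
                           n_episodes=1000)
    agent.run()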
Example 10
def main():
    # Initialize OpenAI_ROS ENV
    LoadYamlFileParamsTest(rospackage_name="learning_ros",
                           rel_path_from_package_to_file="config",
                           yaml_file_name="aliengo_stand.yaml")
    env = StartOpenAI_ROS_Environment('AliengoStand-v0')
    time.sleep(3)
    # Initialize PPO agent
    Policy = Policy_net('policy', ob_space=3, act_space=8)
    Old_Policy = Policy_net('old_policy', ob_space=3, act_space=8)
    PPO = PPOTrain(Policy, Old_Policy, gamma=GAMMA)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        obs = env.reset()
        reward = 0
        success_num = 0
        scores = []

        for iteration in range(ITERATION):
            observations = []
            actions = []
            v_preds = []
            rewards = []
            while not rospy.is_shutdown():  # until ros is not shutdown
                # prepare to feed placeholder Policy.obs
                obs = np.stack([obs]).astype(dtype=np.float32)
                act, v_pred = Policy.act(obs=obs, stochastic=True)
                #print('act: ',act, 'v_pred: ',v_pred )

                act = np.asscalar(act)
                v_pred = np.asscalar(v_pred)
                observations.append(obs)
                actions.append(act)
                v_preds.append(v_pred)
                rewards.append(reward)
                #execute according to action
                next_obs, reward, done, _ = env.step(act)
                time.sleep(0.01)
                if done:
                    # next state of terminate state has 0 state value
                    v_preds_next = v_preds[1:] + [0]
                    obs = env.reset()
                    break
                else:
                    obs = next_obs
            # store the episode return for later visualization
            scores.append(sum(rewards))
            # after 10 consecutive high-reward episodes, training ends (see check below)
            if sum(rewards) >= 400:
                success_num += 1
                print("Succes number: " + str(success_num))
                if success_num >= 5:
                    saver.save(sess, 'model_train/model_ppo.ckpt')
                    print('Clear!! Model saved.')
                if success_num >= 10:
                    saver.save(sess, 'model_train/model_ppo.ckpt')
                    print('Finished! ')
                    break

            else:
                success_num = 0

            gaes = PPO.get_gaes(rewards=rewards,
                                v_preds=v_preds,
                                v_preds_next=v_preds_next)

            # convert list to numpy array for feeding tf.placeholder
            observations = np.reshape(observations, [len(observations), 3])
            actions = np.array(actions).astype(dtype=np.int32)
            rewards = np.array(rewards).astype(dtype=np.float32)
            v_preds_next = np.array(v_preds_next).astype(dtype=np.float32)
            # calculate the generalized advantage estimates (GAE)
            gaes = np.array(gaes).astype(dtype=np.float32)
            gaes = (gaes - gaes.mean())
            print('gaes', gaes)
            #assign current policy params to previous policy params
            PPO.assign_policy_parameters()

            inp = [observations, actions, rewards, v_preds_next, gaes]

            # PPO train
            for epoch in range(4):
                sample_indices = np.random.randint(
                    low=0, high=observations.shape[0],
                    size=64)  # indices are in [low, high)
                sampled_inp = [
                    np.take(a=a, indices=sample_indices, axis=0) for a in inp
                ]  # sample training data
                PPO.train(obs=sampled_inp[0],
                          actions=sampled_inp[1],
                          rewards=sampled_inp[2],
                          v_preds_next=sampled_inp[3],
                          gaes=sampled_inp[4])
        plt.plot(scores)
        plt.show()
        env.stop()
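PPO.get_gaes() is provided by the PPOTrain class used above and is not shown in this excerpt. As a self-contained sketch of generalized advantage estimation consistent with the call site (the actual implementation may differ, e.g. in how lambda is handled):

def get_gaes_sketch(rewards, v_preds, v_preds_next, gamma=0.95, lam=1.0):
    # one-step TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = [r + gamma * v_next - v
              for r, v_next, v in zip(rewards, v_preds_next, v_preds)]
    # accumulate the discounted residuals backwards over the episode
    gaes = list(deltas)
    for t in reversed(range(len(gaes) - 1)):
        gaes[t] = deltas[t] + gamma * lam * gaes[t + 1]
    return gaes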
environment_name = rospy.get_param(
    '/cartpole_v0/task_and_robot_environment_name')
env = StartOpenAI_ROS_Environment(environment_name)
# The stable-baselines algorithms require a vectorized environment to run
env = DummyVecEnv([lambda: env])

model = TRPO(MlpPolicy,
             env,
             verbose=1,
             tensorboard_log="tensorboard_logs/TRPO_cartpole/")
print(model)

# TRAIN
# model.learn(total_timesteps=1000)
# model.save("trained_models/TRPO")

# TEST
model = model.load("trained_models/TRPO")

obs = env.reset()
rate = rospy.Rate(30)

for t in range(1000):
    print(t)
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    rate.sleep()

env.close()