    def sample(self, explore=False):
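        """Epsilon-greedy rollout step for the two-agent leader-follower game.

        The leader (agent 0) acts on a one-hot encoding of its state and is
        handed the follower agent, presumably so it can anticipate the
        follower's response; the follower (agent 1) acts on its one-hot state
        concatenated with the leader's one-hot action. With probability
        ``self._epsilon`` (or whenever ``explore`` is set) both actions are
        drawn uniformly at random and epsilon is decayed. Assumes ``numpy as
        np``, ``tensorflow as tf`` (eager), and the project's ``logger`` /
        ``tabular`` are imported at module level.
        """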
        self.step += 1

        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()

        action_n = []

        supplied_observation = []

        mix_observe_0 = tf.one_hot(self._current_observation_n[0], self.env.num_state)

        explore_this_step = False

        if np.random.random() < self._epsilon or explore:
            explore_this_step = True
            self._epsilon = self._epsilon * self._epsilon_decay
            action_0 = np.array([np.random.randint(self.env.action_num)])
        else:
            action_0 = self.agents[0].act(mix_observe_0, self.agents[1])

        supplied_observation.append(mix_observe_0)
        action_n.append(action_0)

        mix_observe_1 = np.hstack((tf.one_hot(self._current_observation_n[1], self.env.num_state),
                                   tf.one_hot(action_0, self.env.action_num)))
        if explore_this_step:
            # Epsilon was already decayed once for the leader above, so an
            # exploring step decays it twice (once per agent).
            self._epsilon = self._epsilon * self._epsilon_decay
            action_1 = np.array([np.random.randint(self.env.action_num)])
        else:
            action_1 = self.agents[1].act(mix_observe_1)
        supplied_observation.append(mix_observe_1)
        action_n.append(action_1)

        action_n = np.asarray(action_n)

        next_observation_n, reward_n, done_n, info = self.env.step(action_n)

        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
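        # Store the transition for each agent; ``opponent_action`` is every
        # other agent's action, one-hot encoded.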
        for i, agent in enumerate(self.agents):
            opponent_action = action_n[[j for j in range(len(action_n)) if j != i]].flatten()

            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=tf.one_hot(action_n[i], self.env.action_num),
                reward=np.float32(reward_n[i]),
                terminal=np.float32(done_n[i]),
                next_observation=np.float32(next_observation_n[i]),
                opponent_action=tf.one_hot(opponent_action, self.env.action_num)
            )

        self._current_observation_n = next_observation_n

        if self.step % (25 * 1000) == 0:
            # reward_n only holds the current step's per-agent rewards, so
            # report the sampler's episode-level statistics instead.
            print("steps: {}, episodes: {}, last episode return: {}".format(
                self.step, self._n_episodes, self._last_path_return))


        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return, self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
        return action_n
    def sample(self, explore=False):
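        """Greedy rollout step: both agents always act from their policies.

        Observation construction matches the epsilon-greedy variant above;
        the ``explore`` argument is unused here.
        """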
        self.step += 1
        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()

        action_n = []
        supplied_observation = []

        # Leader observes its one-hot state; the follower observes its one-hot
        # state concatenated with the leader's one-hot action.
        mix_observe_0 = tf.one_hot(self._current_observation_n[0], self.env.num_state)
        action_0 = self.agents[0].act(mix_observe_0)
        supplied_observation.append(mix_observe_0)
        action_n.append(action_0)

        mix_observe_1 = np.hstack((tf.one_hot(self._current_observation_n[1], self.env.num_state),
                                   tf.one_hot(action_0, self.env.action_num)))
        action_1 = self.agents[1].act(mix_observe_1)
        supplied_observation.append(mix_observe_1)
        action_n.append(action_1)

        action_n = np.asarray(action_n)

        next_observation_n, reward_n, done_n, info = self.env.step(action_n)
        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
        for i, agent in enumerate(self.agents):
            opponent_action = action_n[[j for j in range(len(action_n)) if j != i]].flatten()
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=tf.one_hot(action_n[i], self.env.action_num),
                reward=np.float32(reward_n[i]),
                terminal=np.float32(done_n[i]),
                next_observation=np.float32(next_observation_n[i]),
                opponent_action=tf.one_hot(opponent_action, self.env.action_num)
            )

        self._current_observation_n = next_observation_n

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return, self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
    def sample(self, explore=False):
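        """Rollout step where both agents act on one-hot state encodings only.

        The follower's observation does not include the leader's action in
        this variant; on ``explore`` a random joint action is drawn from
        ``self.env.action_spaces``.
        """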
        self.step += 1
        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()

        action_n = []
        supplied_observation = []

        if explore:
            mix_observe_0 = tf.one_hot(self._current_observation_n[0],
                                       self.env.num_state)
            supplied_observation.append(mix_observe_0)

            mix_observe_1 = tf.one_hot(self._current_observation_n[1],
                                       self.env.num_state)
            supplied_observation.append(mix_observe_1)

            # Random joint action, reshaped to one action per agent.
            action_n = self.env.action_spaces.sample()
            action_n = action_n.reshape(-1, 1)
        else:
            mix_observe_0 = tf.one_hot(self._current_observation_n[0],
                                       self.env.num_state)
            action_0 = self.agents[0].act(mix_observe_0)
            supplied_observation.append(mix_observe_0)
            action_n.append(action_0)

            mix_observe_1 = tf.one_hot(self._current_observation_n[1],
                                       self.env.num_state)
            action_1 = self.agents[1].act(mix_observe_1)
            supplied_observation.append(mix_observe_1)
            action_n.append(action_1)

        action_n = np.asarray(action_n)

        next_observation_n, reward_n, done_n, info = self.env.step(action_n)
        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
        for i, agent in enumerate(self.agents):
            opponent_action = action_n[[
                j for j in range(len(action_n)) if j != i
            ]].flatten()
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=tf.one_hot(action_n[i], self.env.action_num),
                reward=np.float32(reward_n[i]),
                terminal=np.float32(done_n[i]),
                next_observation=np.float32(next_observation_n[i]),
                opponent_action=tf.one_hot(opponent_action,
                                           self.env.action_num))

        self._current_observation_n = next_observation_n

        if self.step % (25 * 1000) == 0:
            print("steps: {}, episodes: {}, last episode return: {}".format(
                self.step, self._n_episodes, self._last_path_return))

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return,
                                               self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
    def sample(self, explore=False):
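        """Rollout step that stores raw observations and actions (no one-hot).

        Replay fields are cast with ``.astype(np.float32)``, so this variant
        assumes the environment returns numpy arrays for observations,
        rewards and done flags.
        """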
        self.step += 1
        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()
        action_n = []

        supplied_observation = []

        if explore:
            action_n = self.env.action_spaces.sample()

            mix_observe_0 = self._current_observation_n[0]

            supplied_observation.append(mix_observe_0)
            supplied_observation.append(self._current_observation_n[1])

        else:
            mix_observe_0 = self._current_observation_n[0]
            action_0 = self.agents[0].act(mix_observe_0)
            supplied_observation.append(mix_observe_0)
            action_n.append(action_0)

            action_1 = self.agents[1].act(self._current_observation_n[1])

            supplied_observation.append(self._current_observation_n[1])
            action_n.append(action_1)
            action_n = np.asarray(action_n)

        next_observation_n, reward_n, done_n, info = self.env.step(action_n)
        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
        for i, agent in enumerate(self.agents):
            opponent_action = action_n[[
                j for j in range(len(action_n)) if j != i
            ]].flatten()
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=action_n[i].astype(np.float32),
                reward=reward_n[i].astype(np.float32),
                terminal=done_n[i].astype(np.float32),
                next_observation=next_observation_n[i].astype(np.float32),
                opponent_action=opponent_action.astype(np.float32))

        self._current_observation_n = next_observation_n

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return,
                                               self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
    def sample(self, explore=False):
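        """Rollout step where the follower conditions on the leader's raw action.

        As in the variant above, raw observations and actions are stored; the
        follower's stored observation is its own raw observation, while its
        policy input also includes the leader's action.
        """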
        self.step += 1
        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()
        action_n = []

        supplied_observation = []

        if explore:
            action_n = self.env.action_spaces.sample()

            mix_observe_0 = self._current_observation_n[0]
            supplied_observation.append(mix_observe_0)

            # The stored follower observation is its raw observation, matching
            # the non-explore branch below.
            supplied_observation.append(self._current_observation_n[1])

        else:
            mix_observe_0 = self._current_observation_n[0]
            action_0 = self.agents[0].act(mix_observe_0)
            supplied_observation.append(mix_observe_0)
            action_n.append(action_0)

            # The follower's policy input is its observation with the leader's
            # action appended.
            mix_observe_1 = np.hstack((self._current_observation_n[1], action_0))
            action_1 = self.agents[1].act(mix_observe_1)

            supplied_observation.append(self._current_observation_n[1])
            action_n.append(action_1)
            action_n = np.asarray(action_n)


        next_observation_n, reward_n, done_n, info = self.env.step(action_n)
        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
        for i, agent in enumerate(self.agents):
            opponent_action = action_n[[
                j for j in range(len(action_n)) if j != i
            ]].flatten()
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=action_n[i].astype(np.float32),
                reward=reward_n[i].astype(np.float32),
                terminal=done_n[i].astype(np.float32),
                next_observation=next_observation_n[i].astype(np.float32),
                opponent_action=opponent_action.astype(np.float32))

        self._current_observation_n = next_observation_n

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return,
                                               self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
    def sample(self, explore=False):
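        """Rollout step for the two-vehicle leader-follower merging scenario.

        Observations are built locally from the environment: a one-hot
        encoding of ``self.env.sim_step`` whose last two slots are overwritten
        with the relative longitudinal gap and relative velocity between the
        two vehicles, rescaled to [-1, 1] via ``utils.remap``.
        """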
        self.step += 1

        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()

        action_n = []
        supplied_observation = []

        observations = np.zeros((2, self.env.num_state))
        next_observations = np.zeros((2, self.env.num_state))
        # The last two observation slots are reserved for relative info, and
        # the next step's one-hot index must also stay clear of them.
        if self.env.sim_step >= self.env.num_state - 3:
            print('wrong')
        observations[0][self.env.sim_step] = 1
        observations[1][self.env.sim_step] = 1
        next_observations[0][self.env.sim_step + 1] = 1
        next_observations[1][self.env.sim_step + 1] = 1
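        # Relative longitudinal gap and relative velocity between the two
        # vehicles, rescaled to [-1, 1]; row 1 is the negation of row 0 since
        # the quantities are antisymmetric between the vehicles.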
        relative_info = np.zeros((2, 2))
        speed_max = 40
        velocity_range = 2 * speed_max
        x_position_range = speed_max
        delta_dx = (self.env.road.vehicles[1].position[0]
                    - self.env.road.vehicles[0].position[0])
        delta_vx = (self.env.road.vehicles[1].velocity
                    - self.env.road.vehicles[0].velocity)
        relative_info[0][0] = utils.remap(
            delta_dx, [-x_position_range, x_position_range], [-1, 1])
        relative_info[0][1] = utils.remap(delta_vx,
                                          [-velocity_range, velocity_range],
                                          [-1, 1])
        relative_info[1][0] = -relative_info[0][0]
        relative_info[1][1] = -relative_info[0][1]
        observations[:, -2:] = relative_info
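        # Leaders observe their own row of ``observations``; each follower's
        # observation additionally includes the first leader's one-hot action
        # (``action_n[0]``).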
        if explore:
            for i in range(self.agent_num):
                action_n.append([np.random.randint(0, self.env.action_num)])

            for i in range(self.leader_num):
                supplied_observation.append(observations[i])
            for i in range(self.leader_num, self.env.agent_num):
                mix_obs = np.hstack(
                    (observations[i],
                     tf.one_hot(action_n[0][0],
                                self.env.action_num))).reshape(1, -1)
                supplied_observation.append(mix_obs)

        else:
            for i in range(self.leader_num):
                supplied_observation.append(observations[i])
                action_n.append(self.train_agents[0].act(
                    observations[i].reshape(1, -1)))

            for i in range(self.leader_num, self.env.agent_num):
                mix_obs = np.hstack(
                    (observations[i],
                     tf.one_hot(action_n[0][0],
                                self.env.action_num))).reshape(1, -1)
                supplied_observation.append(mix_obs)
                follower_action = self.train_agents[1].act(
                    mix_obs.reshape(1, -1))
                action_n.append(follower_action)

        action_n = np.asarray(action_n)

        pres_valid_conditions_n = []
        next_valid_conditions_n = []

        next_observation_n, reward_n, done_n, info = self.env.step(action_n)

        delta_dx = (self.env.road.vehicles[1].position[0]
                    - self.env.road.vehicles[0].position[0])
        delta_vx = (self.env.road.vehicles[1].velocity
                    - self.env.road.vehicles[0].velocity)
        relative_info[0][0] = utils.remap(
            delta_dx, [-x_position_range, x_position_range], [-1, 1])
        relative_info[0][1] = utils.remap(delta_vx,
                                          [-velocity_range, velocity_range],
                                          [-1, 1])
        relative_info[1][0] = -relative_info[0][0]
        relative_info[1][1] = -relative_info[0][1]
        next_observations[:, -2:] = relative_info

        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1

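        # Here ``opponent_action`` is the flattened joint action of all agents
        # (the agent's own action included), unlike the earlier variants which
        # excluded the agent's own action.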
        opponent_action = np.array(action_n[[j for j in range(len(action_n))
                                             ]].flatten())
        for i, agent in enumerate(self.agents):
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=tf.one_hot(action_n[i], self.env.action_num),
                reward=np.float32(reward_n[i]),
                terminal=np.float32(done_n[i]),
                next_observation=np.float32(next_observations[i]),
                opponent_action=np.int32(opponent_action),
            )

        self._current_observation_n = next_observation_n

        if self.step % (25 * 1000) == 0:
            print("steps: {}, episodes: {}, last episode return: {}".format(
                self.step, self._n_episodes, self._last_path_return))

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()
            self._max_path_return = np.maximum(self._max_path_return,
                                               self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n
    def sample(self, explore=False):
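        """Rollout step for the multi-vehicle leader-follower highway scenario.

        Each follower observes its own observation concatenated with the
        one-hot action of its closest leader vehicle; vehicles flagged invalid
        by the environment take the idle action instead of acting. Per-agent
        validity flags before and after the step are stored with every
        transition.
        """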
        self.step += 1
        if self._current_observation_n is None:
            self._current_observation_n = self.env.reset()

        action_n = []
        supplied_observation = []

        if explore:
            # Vehicles flagged invalid by the environment take the idle action.
            for i in range(self.agent_num):
                if self.env.is_vehicles_valid[i]:
                    action_n.append([np.random.randint(0, self.env.action_num)])
                else:
                    action_n.append([self.idle_action])

            for i in range(self.leader_num):
                supplied_observation.append(self._current_observation_n[i])

            for i in range(self.leader_num, self.env.agent_num):
                v = self.env.road.vehicles[i]
                # Leader vehicles ordered by distance to this follower.
                closest_leaders = self.env.road.closest_leader_vehicles_to(v, 1)
                leader_actions_concat = []
                for leader in closest_leaders:
                    leader_actions_concat = np.append(
                        leader_actions_concat,
                        tf.one_hot(action_n[leader.index], self.env.action_num))
                mix_obs = np.hstack((self._current_observation_n[i],
                                     leader_actions_concat)).reshape(1, -1)
                supplied_observation.append(mix_obs)
        else:
            for i in range(self.leader_num):
                supplied_observation.append(self._current_observation_n[i])
                if self.env.is_vehicles_valid[i]:
                    action_n.append(self.train_agents[self.leader_idx].act(
                        self._current_observation_n[i].reshape(1, -1)))
                else:
                    action_n.append([self.idle_action])

            for i in range(self.leader_num, self.env.agent_num):
                v = self.env.road.vehicles[i]
                # Leader vehicles ordered by distance to this follower.
                close_leaders = self.env.road.closest_leader_vehicles_to(v, 1)
                leader_actions_concat = []
                for leader in close_leaders:
                    leader_actions_concat = np.append(
                        leader_actions_concat,
                        tf.one_hot(action_n[leader.index], self.env.action_num))
                mix_obs = np.hstack((self._current_observation_n[i],
                                     leader_actions_concat)).reshape(1, -1)
                supplied_observation.append(mix_obs)
                if self.env.is_vehicles_valid[i]:
                    follower_action = self.train_agents[self.follower_idx].act(
                        mix_obs.reshape(1, -1))
                    action_n.append(follower_action)
                else:
                    action_n.append([self.idle_action])
        action_n = np.asarray(action_n)
        
        pres_valid_conditions_n = []
        next_valid_conditions_n = []
        
        # Record which vehicles are valid before stepping (1 = valid).
        for i, agent in enumerate(self.agents):
            if not self.env.is_vehicles_valid[i]:
                pres_valid_conditions_n.append(0)
            else:
                pres_valid_conditions_n.append(1)
        next_observation_n, reward_n, done_n, info = self.env.step(action_n)
        self.env.render()
        
        # Record which vehicles are valid after stepping.
        for i, agent in enumerate(self.agents):
            if not self.env.is_vehicles_valid[i]:
                next_valid_conditions_n.append(0)
            else:
                next_valid_conditions_n.append(1)
        
        if self._global_reward:
            reward_n = np.array([np.sum(reward_n)] * self.agent_num)

 

        self._path_length += 1
        self._path_return += np.array(reward_n, dtype=np.float32)
        self._total_samples += 1
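        # Each transition stores the flattened joint action plus the validity
        # flags computed above, presumably so that invalid vehicles can be
        # masked out during training.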
        opponent_action = np.array(action_n[[j for j in range(len(action_n))]].flatten())
        for i, agent in enumerate(self.agents):
            agent.replay_buffer.add_sample(
                observation=supplied_observation[i],
                action=tf.one_hot(action_n[i], self.env.action_num),
                reward=np.float32(reward_n[i]),
                terminal=np.float32(done_n[i]),
                next_observation=np.float32(next_observation_n[i]),
                opponent_action=np.int32(opponent_action),
                pres_valid_conditions=np.int16(pres_valid_conditions_n[i]),
                next_valid_conditions=np.int16(next_valid_conditions_n[i]),
            )
        
             
        self._current_observation_n = next_observation_n

        if self.step % (25 * 1000) == 0:
            print("steps: {}, episodes: {}, last episode return: {}".format(
                self.step, self._n_episodes, self._last_path_return))

        if np.all(done_n) or self._path_length >= self._max_path_length:
            self._current_observation_n = self.env.reset()

            self._max_path_return = np.maximum(self._max_path_return, self._path_return)
            self._mean_path_return = self._path_return / self._path_length
            self._last_path_return = self._path_return
            self._path_length = 0

            self._path_return = np.zeros(self.agent_num)
            self._n_episodes += 1
            self.log_diagnostics()
            logger.log(tabular)
            logger.dump_all()
        else:
            self._current_observation_n = next_observation_n