Example #1
    def work(self,
             max_episode_length,
             gamma,
             sess,
             coord=None,
             saver=None,
             max_episodes=None,
             batch_size=25):
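        """A3C-style worker loop with inter-agent communication.

        Repeatedly syncs the local actor-critic from the global network,
        rolls out episodes of up to `max_episode_length` steps while agents
        exchange messages (subject to optional delivery failure, Gaussian
        noise, and jumbling), trains on truncated rollouts of `batch_size`
        steps and again at episode end, and periodically writes TensorBoard
        summaries and checkpoints. Runs until `coord` requests a stop or
        `max_episodes` global episodes have completed.
        """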
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        print("Starting worker " + str(self.number))
        # with sess.as_default(), sess.graph.as_default():
        prev_clock = time()
        if coord is None:
            coord = sess

        already_calculated_actions = False
        stats1, stats2 = [], []
        while not coord.should_stop():
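            # Pull the latest global weights into the local network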
            sess.run(self.update_local_ops)

            episode_buffer = [[] for _ in range(self.number_of_agents)]
            episode_comm_maps = [[] for _ in range(self.number_of_agents)]
            episode_values = [[] for _ in range(self.number_of_agents)]
            episode_reward = 0
            episode_step_count = 0

            v_l, p_l, e_l, g_n, v_n = get_empty_loss_arrays(
                self.number_of_agents)
            partial_obs = [None for _ in range(self.number_of_agents)]
            partial_mess_rec = [None for _ in range(self.number_of_agents)]
            sent_message = [None for _ in range(self.number_of_agents)]
            mgrad_per_received = [None for _ in range(self.number_of_agents)]

            # Start a new episode
            current_screen = self.env.reset()
            # print(current_screen)
            for i in range(self.number_of_agents):
                if self.s_size == 6:
                    current_screen[i] = np.hstack([
                        current_screen[i][0:4],
                        current_screen[i][4 + i * 2:6 + i * 2]
                    ])
                else:
                    current_screen[i] = current_screen[i][0:self.s_size]

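            # Centralized critic input: every agent sees the concatenation
            # of all agents' observations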
            current_screen_central = np.hstack(current_screen)
            arrayed_current_screen_central = [
                current_screen_central for _ in range(self.number_of_agents)
            ]
            for i in range(self.number_of_agents):
                comm_map = list(range(self.number_of_agents))
                comm_map.remove(i)
                episode_comm_maps[i].append(comm_map)

            if self.is_chief and self.display:
                self.env.render()

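            # Start each agent with zeroed incoming messages, one slot of
            # message_size per other agent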
            curr_comm = [[] for _ in range(self.number_of_agents)]
            for curr_agent in range(self.number_of_agents):
                for from_agent in range(self.number_of_agents - 1):
                    curr_comm[curr_agent].extend([0] * self.message_size)

            stats1_temp = 0

            for episode_step_count in range(max_episode_length):
                # feedforward pass
                if not already_calculated_actions:
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(self.action_indexes,
                                         p=act_distribution)
                        for act_distribution in action_distribution
                    ]
                    actions_one_hot = [
                        self.actions_one_hot[act] for act in actions
                    ]

                    # Add Gaussian noise to the outgoing messages
                    if self.comm_gaussian_noise != 0:
                        for index in range(len(message)):
                            message[index] += np.random.normal(
                                0, self.comm_gaussian_noise)

                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                actions_one_hot)
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                curr_comm[agent])
                already_calculated_actions = False
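                # Centralized critic: estimate the value of the joint state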
                value = sess.run(self.local_AC.value,
                                 feed_dict={
                                     self.local_AC.inputs_central:
                                     arrayed_current_screen_central
                                 })

                previous_screen = current_screen
                arrayed_previous_screen_central = arrayed_current_screen_central
                previous_comm = curr_comm

                # Step the environment with the chosen actions
                current_screen, reward, terminal, info = self.env.step(
                    actions_one_hot)
                for info_agent in info['n']:
                    stats1_temp += (info_agent[1] - 1) / 2

                for i in range(self.number_of_agents):
                    current_screen[i] = current_screen[i][0:self.s_size]
                terminal = (np.sum(reward) > -0.25 * self.number_of_agents
                            or stats1_temp > 5)
                current_screen_central = np.hstack(current_screen)
                arrayed_current_screen_central = [
                    current_screen_central
                    for _ in range(self.number_of_agents)
                ]

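                # Rebuild the communication map for this step; links that
                # fail delivery are marked with -1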
                this_turns_comm_map = []
                for i in range(self.number_of_agents):
                    # each incoming link may fail independently with
                    # probability comm_delivery_failure_chance
                    surviving_comms = list(range(self.number_of_agents))
                    surviving_comms.remove(i)
                    for index in range(len(surviving_comms)):
                        if random.random() < self.comm_delivery_failure_chance:
                            # this message is lost in transit
                            surviving_comms[index] = -1
                    episode_comm_maps[i].append(surviving_comms)
                    this_turns_comm_map.append(surviving_comms)
                curr_comm = self.output_mess_to_input_mess(
                    message, this_turns_comm_map)

                if self.is_chief and self.display:
                    self.env.render()
                    sleep(0.2)
                episode_reward += sum(
                    reward) if self.spread_rewards else reward

                # Jumble comms: a received message may be replaced by the
                # element-wise sum of all messages arriving at that agent
                if self.comm_jumble_chance != 0:
                    for i in range(self.number_of_agents):
                        joint_comm = [0] * self.message_size
                        for index in range(len(curr_comm[i])):
                            joint_comm[index % self.message_size] += \
                                curr_comm[i][index]
                        jumble = False
                        for index in range(len(curr_comm[i])):
                            if index % self.message_size == 0:
                                # only jumble messages that got received
                                jumble = (curr_comm[i][index] != 0
                                          and random.random() <
                                          self.comm_jumble_chance)
                            if jumble:
                                curr_comm[i][index] = \
                                    joint_comm[index % self.message_size]

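                # Store each agent's transition: (obs, central obs, incoming
                # comm, action, sent message, reward, next obs, next comm,
                # terminal, value)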
                for i in range(self.number_of_agents):
                    episode_buffer[i].append([
                        previous_screen[i], arrayed_previous_screen_central[i],
                        previous_comm[i], actions[i], message[i],
                        reward[i] if self.spread_rewards else reward,
                        current_screen[i], curr_comm[i], terminal, value[i]
                    ])
                    episode_values[i].append(np.max(value[i]))

                # If the episode hasn't ended, but the experience buffer is full, then we make an update step
                # using that experience rollout.
                if len(episode_buffer[0]) == batch_size and not terminal and \
                                episode_step_count < max_episode_length - 1:
                    # feedforward pass
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(self.action_indexes,
                                         p=act_distribution)
                        for act_distribution in action_distribution
                    ]
                    actions_one_hot = [
                        self.actions_one_hot[act] for act in actions
                    ]

                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                actions_one_hot)
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                curr_comm[agent])
                    already_calculated_actions = True
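                    # Bootstrap from the current value estimate to close
                    # the truncated rollout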
                    v1 = sess.run(self.local_AC.value,
                                  feed_dict={
                                      self.local_AC.inputs_central:
                                      arrayed_current_screen_central
                                  })

                    for i in range(self.number_of_agents):
                        partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                        v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                            self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC,
                                                                      bootstrap_value=v1[i][0])

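                    # Map the gradients w.r.t. received messages back onto
                    # the agents that sent them (via the comm maps)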
                    if self.comm:
                        mgrad_per_sent = self.input_mloss_to_output_mloss(
                            batch_size - 1, mgrad_per_received,
                            episode_comm_maps)

                        # start a new mini batch with only the last sample in the comm map
                        temp_episode_comm_maps = []
                        for i in range(self.number_of_agents):
                            temp_episode_comm_maps.append(
                                [episode_comm_maps[i][-1]])
                        episode_comm_maps = temp_episode_comm_maps
                        for i in range(self.number_of_agents):
                            self.apply_comm_gradients(partial_obs[i],
                                                      partial_mess_rec[i],
                                                      sent_message[i],
                                                      mgrad_per_sent[i], sess,
                                                      self.local_AC)

                    # Copy the updated global weights into the local network
                    sess.run(self.update_local_ops)

                    # Reset the episode buffers, keeping the last sample
                    # for the t-1 message loss
                    temp_episode_buffer = []
                    for i in range(self.number_of_agents):
                        temp_episode_buffer.append([episode_buffer[i][-1]])
                    episode_buffer = temp_episode_buffer

                # Track throughput and count total environment steps
                total_steps += 1
                if total_steps % 2000 == 0:
                    new_clock = time()
                    print(2000.0 / (new_clock - prev_clock), "it/s")
                    prev_clock = new_clock

                # If the episode has terminated, break out of the step loop
                if terminal:
                    break

            stats1.append(stats1_temp)
            stats2.append(-reward[0] / self.number_of_agents)

            # print("Over", episode_step_count, episode_reward)
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_step_count)
            self.episode_mean_values.append(np.mean(episode_values))

            # Update the network using the experience buffer at the end of the episode.
            for i in range(self.number_of_agents):
                partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                    self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC)

            if self.comm and len(mgrad_per_received[0]) != 0:
                mgrad_per_sent = self.input_mloss_to_output_mloss(
                    len(mgrad_per_received[0]), mgrad_per_received,
                    episode_comm_maps)

                for i in range(self.number_of_agents):
                    self.apply_comm_gradients(partial_obs[i],
                                              partial_mess_rec[i],
                                              sent_message[i],
                                              mgrad_per_sent[i], sess,
                                              self.local_AC)

            # Copy the updated global weights into the local network
            sess.run(self.update_local_ops)

            # Periodically save model parameters and summary statistics.
            if episode_count % 5 == 0:

                # Save statistics for TensorBoard
                mean_length = np.mean(self.episode_lengths[-5:])
                mean_reward = np.mean(self.episode_rewards[-5:])
                mean_value = np.mean(self.episode_mean_values[-5:])

                if self.is_chief and episode_count % 10 == 0:
                    print("length", mean_length, "max reward", np.max(reward))

                # Save current model
                if self.is_chief and saver is not None and episode_count % 500 == 0:
                    saver.save(
                        sess, self.model_path + '/model-' +
                        str(episode_count) + '.cptk')
                    print("Saved Model")

                summary = tf.Summary()
                summary.value.add(
                    tag='Perf/Length',
                    simple_value=float(mean_length))  # avg episode length
                summary.value.add(
                    tag='Perf/Reward',
                    simple_value=float(mean_reward))  # avg reward
                summary.value.add(
                    tag='Perf/Value', simple_value=float(
                        mean_value))  # avg episode value_predator
                summary.value.add(tag='Losses/Value Loss',
                                  simple_value=float(
                                      np.mean(v_l)))  # value_loss
                summary.value.add(tag='Losses/Policy Loss',
                                  simple_value=float(
                                      np.mean(p_l)))  # policy_loss
                summary.value.add(tag='Losses/Entropy',
                                  simple_value=float(np.mean(e_l)))  # entropy
                summary.value.add(tag='Losses/Grad Norm',
                                  simple_value=float(
                                      np.mean(g_n)))  # grad_norms
                summary.value.add(tag='Losses/Var Norm',
                                  simple_value=float(
                                      np.mean(v_n)))  # var_norms
                self.summary_writer.add_summary(summary, episode_count)
                self.summary_writer.flush()

            # Update episode count
            if self.is_chief:
                episode_count = sess.run(self.increment)
                if episode_count % 50 == 0:
                    print("Global episodes @", episode_count)

                if max_episodes is not None and episode_count > max_episodes:
                    coord.request_stop()
            else:
                episode_count = sess.run(self.global_episodes)

        self.env.close()
        print(stats1)
        print(stats2)
        print(np.mean(stats1), np.mean(stats2))
Example #2
    def work(self,
             max_episode_length,
             gamma,
             sess,
             coord=None,
             saver=None,
             max_episodes=None,
             batch_size=25):
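        """A3C-style worker loop with inter-agent communication.

        Identical in structure to Example #1: sync the local actor-critic
        from the global network, roll out episodes with message exchange
        (optional delivery failure, Gaussian noise, and jumbling), train on
        truncated rollouts of `batch_size` steps and at episode end, and
        periodically write TensorBoard summaries and checkpoints. This
        variant additionally tracks the `collisions` and `stalls` counters
        reported by the environment.
        """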
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        print("Starting worker " + str(self.number))
        # with sess.as_default(), sess.graph.as_default():
        prev_clock = time()
        if coord is None:
            coord = sess
        """comms_evol = [[], [], [], [], [], [], [], []]
        if self.display and self.is_chief:
            mpl.ion()
            mpl.pause(0.001)"""

        already_calculated_actions = False
        while not coord.should_stop():
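            # Pull the latest global weights into the local network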
            sess.run(self.update_local_ops)

            episode_buffer = [[] for _ in range(self.number_of_agents)]
            episode_comm_maps = [[] for _ in range(self.number_of_agents)]
            episode_values = [[] for _ in range(self.number_of_agents)]
            episode_reward = 0
            episode_step_count = 0
            action_indexes = list(range(self.env.max_actions))

            v_l, p_l, e_l, g_n, v_n = get_empty_loss_arrays(
                self.number_of_agents)
            partial_obs = [None for _ in range(self.number_of_agents)]
            partial_mess_rec = [None for _ in range(self.number_of_agents)]
            sent_message = [None for _ in range(self.number_of_agents)]
            mgrad_per_received = [None for _ in range(self.number_of_agents)]

            # Start a new episode
            current_screen, info = self.env.reset()
            arrayed_current_screen_central = info["state_central"]
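            # Entries 1..3 of each local observation appear to identify the
            # neighboring cars; they seed this step's comm map before being
            # reduced to presence flags below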
            for i in range(self.number_of_agents):
                episode_comm_maps[i].append(current_screen[i][1:4])

                # Keep only whether a car was present in each neighbor slot,
                # not which car it was
                for neighbor_index in range(1, 4):
                    if current_screen[i][neighbor_index] >= 0:
                        current_screen[i][neighbor_index] = 1

            if self.is_chief and self.display:
                self.env.render()

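            # Start each agent with zeroed incoming messages, one slot of
            # message_size per sender it can hear from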
            curr_comm = [[] for _ in range(self.number_of_agents)]
            for curr_agent in range(self.number_of_agents):
                for from_agent in range(
                        self.amount_of_agents_to_send_message_to):
                    curr_comm[curr_agent].extend([0] * self.message_size)

            for episode_step_count in range(max_episode_length):
                # feedforward pass
                if not already_calculated_actions:
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(action_indexes, p=act_distribution)
                        for act_distribution in action_distribution
                    ]
                    # Add Gaussian noise to the outgoing messages
                    if self.comm_gaussian_noise != 0:
                        for index in range(len(message)):
                            message[index] += np.random.normal(
                                0, self.comm_gaussian_noise)

                    # TODO
                    """message = []
                    for my_obs in current_screen:
                        message.append([my_obs[0]])
                    for index in range(len(actions)):
                        # intersection, failed message or other guy has priority, lets await
                        if current_screen[index][3] == 1 and curr_comm[index][2] >= 0 and current_screen[index][0]==-1:
                            actions[index] = 0
                        else:
                            actions[index] = 1"""
                    """for index in range(len(actions)):
                        # intersection, failed message or other guy has priority, lets await
                        if current_screen[index][3] == 1 and curr_comm[index][10]==0 and current_screen[index][0]!=1:
                            print("Actions: %4.2f %4.2f"%(action_distribution[index][0],action_distribution[index][1]))"""

                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                actions_one_hot)
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                curr_comm[agent])
                already_calculated_actions = False
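                # Centralized critic: estimate the value of the joint state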
                value = sess.run(self.local_AC.value,
                                 feed_dict={
                                     self.local_AC.inputs_central:
                                     arrayed_current_screen_central
                                 })

                previous_screen = current_screen
                arrayed_previous_screen_central = arrayed_current_screen_central
                previous_comm = curr_comm

                # Step the environment with the chosen actions
                current_screen, reward, terminal, info = self.env.step(actions)

                # debug
                """if episode_count % 500 == 0 and self.is_chief:
                    for index in range(len(actions)):
                        ""if previous_screen[index][3] == 1 and previous_comm[index][10] == 0 and previous_screen[index][0] != 1:
                            print(episode_step_count, "Agent ", index, ", Actions: %4.2f %4.2f" % (
                            action_distribution[index][0], action_distribution[index][1]), "action:",actions[index], ", Reward=",
                                  reward[index], "mess", previous_comm[index])""
                        if previous_screen[index][3] == 1 and previous_screen[index][0] != 1:
                            print(episode_step_count, "Agent ", index, ", Actions: %4.2f %4.2f" % (
                            action_distribution[index][0], action_distribution[index][1]), "action:",actions[index], ", Reward=",
                                  reward[index], "mess", previous_comm[index])"""

                arrayed_current_screen_central = info["state_central"]
                this_turns_comm_map = []
                for i in range(self.number_of_agents):
                    # each incoming link may fail independently with
                    # probability comm_delivery_failure_chance
                    surviving_comms = current_screen[i][1:4]
                    for index in range(len(surviving_comms)):
                        if random.random() < self.comm_delivery_failure_chance:
                            # this message is lost in transit
                            surviving_comms[index] = -1
                    episode_comm_maps[i].append(surviving_comms)
                    this_turns_comm_map.append(surviving_comms)
                    # Keep only whether a car was present in each neighbor
                    # slot, not which car it was
                    for neighbor_index in range(1, 4):
                        if current_screen[i][neighbor_index] >= 0:
                            current_screen[i][neighbor_index] = 1
                curr_comm = self.output_mess_to_input_mess(
                    message, this_turns_comm_map)

                # Jumble comms: a received message may be replaced by the
                # element-wise sum of all messages arriving at that agent
                if self.comm_jumble_chance != 0:
                    for i in range(self.number_of_agents):
                        joint_comm = [0] * self.message_size
                        for index in range(len(curr_comm[i])):
                            joint_comm[index % self.message_size] += \
                                curr_comm[i][index]
                        jumble = False
                        for index in range(len(curr_comm[i])):
                            if index % self.message_size == 0:
                                # only jumble messages that got received
                                jumble = (curr_comm[i][index] != 0
                                          and random.random() <
                                          self.comm_jumble_chance)
                            if jumble:
                                curr_comm[i][index] = \
                                    joint_comm[index % self.message_size]

                if self.is_chief and self.display:
                    self.env.render()
                    sleep(0.2)
                episode_reward += sum(
                    reward) if self.spread_rewards else reward

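                # Store each agent's transition: (obs, central obs, incoming
                # comm, action, sent message, reward, next obs, next comm,
                # terminal, value)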
                for i in range(self.number_of_agents):
                    episode_buffer[i].append([
                        previous_screen[i], arrayed_previous_screen_central[i],
                        previous_comm[i], actions[i], message[i],
                        reward[i] if self.spread_rewards else reward,
                        current_screen[i], curr_comm[i], terminal, value[i]
                    ])
                    episode_values[i].append(np.max(value[i]))

                # If the episode hasn't ended, but the experience buffer is full, then we make an update step
                # using that experience rollout.
                if len(episode_buffer[0]) == batch_size and not terminal and \
                                episode_step_count < max_episode_length - 1:
                    # feedforward pass
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(action_indexes, p=act_distribution)
                        for act_distribution in action_distribution
                    ]

                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                actions_one_hot)
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[agent] = (
                                arrayed_current_screen_central[agent] +
                                curr_comm[agent])
                    already_calculated_actions = True

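                    # Bootstrap from the current value estimate to close
                    # the truncated rollout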
                    v1 = sess.run(self.local_AC.value,
                                  feed_dict={
                                      self.local_AC.inputs_central:
                                      arrayed_current_screen_central
                                  })
                    for i in range(self.number_of_agents):
                        partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                        v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                            self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC,
                                                                      bootstrap_value=v1[i][0])

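                    # Map the gradients w.r.t. received messages back onto
                    # the agents that sent them (via the comm maps)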
                    if self.comm:
                        mgrad_per_sent = self.input_mloss_to_output_mloss(
                            batch_size - 1, mgrad_per_received,
                            episode_comm_maps)

                        # start a new mini batch with only the last sample in the comm map
                        temp_episode_comm_maps = []
                        for i in range(self.number_of_agents):
                            temp_episode_comm_maps.append(
                                [episode_comm_maps[i][-1]])
                        episode_comm_maps = temp_episode_comm_maps
                        for i in range(self.number_of_agents):
                            self.apply_comm_gradients(partial_obs[i],
                                                      partial_mess_rec[i],
                                                      sent_message[i],
                                                      mgrad_per_sent[i], sess,
                                                      self.local_AC)
                            # exit()
                    # Copy the updated global weights into the local network
                    sess.run(self.update_local_ops)

                    # Reset the episode buffers, keeping the last sample
                    # for the t-1 message loss
                    temp_episode_buffer = []
                    for i in range(self.number_of_agents):
                        temp_episode_buffer.append([episode_buffer[i][-1]])
                    episode_buffer = temp_episode_buffer

                # Track throughput and count total environment steps
                total_steps += 1
                if total_steps % 2000 == 0:
                    new_clock = time()
                    print(2000.0 / (new_clock - prev_clock), "it/s")
                    prev_clock = new_clock

                # If the episode has terminated, break out of the step loop
                if terminal:
                    break

            # print("Over", episode_step_count, episode_reward)
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_step_count)
            self.episode_collisions.append(info["collisions"])
            self.episode_stalls.append(info["stalls"])
            self.episode_mean_values.append(np.mean(episode_values))

            # Update the network using the experience buffer at the end of the episode.
            for i in range(self.number_of_agents):
                partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                    self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC)

            if self.comm and len(mgrad_per_received[0]) != 0:
                mgrad_per_sent = self.input_mloss_to_output_mloss(
                    len(mgrad_per_received[0]), mgrad_per_received,
                    episode_comm_maps)

                for i in range(self.number_of_agents):
                    self.apply_comm_gradients(partial_obs[i],
                                              partial_mess_rec[i],
                                              sent_message[i],
                                              mgrad_per_sent[i], sess,
                                              self.local_AC)

            # Copy the updated global weights into the local network
            sess.run(self.update_local_ops)

            # Periodically save model parameters and summary statistics.
            if episode_count % 5 == 0:

                # Save statistics for TensorBoard
                mean_length = np.mean(self.episode_lengths[-5:])
                mean_collisions = np.mean(self.episode_collisions[-5:])
                mean_stalls = np.mean(self.episode_stalls[-5:])
                mean_reward = np.mean(self.episode_rewards[-5:])
                mean_value = np.mean(self.episode_mean_values[-5:])

                if self.is_chief and episode_count % 10 == 0:
                    print("length", mean_length, "reward", mean_reward,
                          "collisions", mean_collisions, "stalls", mean_stalls)

                # Save current model
                if self.is_chief and saver is not None and episode_count % 500 == 0:
                    saver.save(
                        sess, self.model_path + '/model-' +
                        str(episode_count) + '.cptk')
                    print("Saved Model")

                summary = tf.Summary()
                summary.value.add(
                    tag='Perf/Length',
                    simple_value=float(mean_length))  # avg episode length
                summary.value.add(
                    tag='Perf/Reward',
                    simple_value=float(mean_reward))  # avg reward
                summary.value.add(
                    tag='Perf/Value', simple_value=float(
                        mean_value))  # avg episode value
                summary.value.add(tag='Losses/Value Loss',
                                  simple_value=float(
                                      np.mean(v_l)))  # value_loss
                summary.value.add(tag='Losses/Policy Loss',
                                  simple_value=float(
                                      np.mean(p_l)))  # policy_loss
                summary.value.add(tag='Losses/Entropy',
                                  simple_value=float(np.mean(e_l)))  # entropy
                summary.value.add(tag='Losses/Grad Norm',
                                  simple_value=float(
                                      np.mean(g_n)))  # grad_norms
                summary.value.add(tag='Losses/Var Norm',
                                  simple_value=float(
                                      np.mean(v_n)))  # var_norms
                self.summary_writer.add_summary(summary, episode_count)
                self.summary_writer.flush()
            """if self.is_chief and self.message_size == 1:
                vals = sess.run(self.local_AC.message,
                                feed_dict={self.local_AC.inputs: [[-1, -1, -1, 1],
                                                                  [1, -1, -1, 1],
                                                                  [-1, 1, -1, 1],
                                                                  [1, 1, -1, 1],
                                                                  [-1, -1, 1, 1],
                                                                  [1, -1, 1, 1],
                                                                  [-1, 1, 1, 1],
                                                                  [1, 1, 1, 1]
                                                                  ]})
                for i in range(8):
                    comms_evol[i].append(vals[i])
                if self.display:
                    mpl.clf()
                    for i in range(8):
                        mpl.plot(comms_evol[i], color="green" if i % 2 == 0 else "red")
                    mpl.pause(0.001)
                    if episode_count%50==0:
                        print(comms_evol)"""

            # Update episode count
            if self.is_chief:
                episode_count = sess.run(self.increment)
                if episode_count % 50 == 0:
                    print("Global episodes @", episode_count)

                if max_episodes is not None and episode_count > max_episodes:
                    coord.request_stop()
            else:
                episode_count = sess.run(self.global_episodes)

        self.env.close()