Example #1
    def work(self, max_episode_length, gamma, sess, coord=None, saver=None, max_episodes=None, batch_size=25):
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        print("Starting worker " + str(self.number))
        # with sess.as_default(), sess.graph.as_default():
        prev_clock = time()
        if coord is None:
            coord = sess
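            # no Coordinator was supplied: fall back to the session object, which
            # is then expected to provide the should_stop()/request_stop()
            # interface used below (e.g. a monitored-session wrapper)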

        #comms_evol = [[], [], [], []]
        #if self.display and self.is_chief:
        #    mpl.ion()
        #    mpl.pause(0.001)

        while not coord.should_stop():
            sess.run(self.update_local_ops)

            episode_buffer = [[] for _ in range(self.number_of_agents)]
            episode_comm_maps = [[] for _ in range(self.number_of_agents)]
            episode_values = [[] for _ in range(self.number_of_agents)]
            episode_reward = 0
            episode_step_count = 0
            action_indexes = list(range(self.env.max_actions))

            v_l, p_l, e_l, g_n, v_n = get_empty_loss_arrays(self.number_of_agents)
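            # per-agent accumulators for value loss, policy loss, entropy,
            # gradient norm and variable norm; filled by the update steps below
            # and averaged into the TensorBoard summaries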
            partial_obs = [None for _ in range(self.number_of_agents)]
            partial_mess_rec = [None for _ in range(self.number_of_agents)]
            sent_message = [None for _ in range(self.number_of_agents)]
            mgrad_per_received = [None for _ in range(self.number_of_agents)]

            # start a new episode
            current_screen, _ = self.env.reset()
            for i in range(self.number_of_agents):
                episode_comm_maps[i].append(current_screen[i][1:4])

                # binarize the neighbour slots: record only whether a car was there, not which car
                for neighbor_index in range(1, 4):
                    if current_screen[i][neighbor_index] >= 0:
                        current_screen[i][neighbor_index] = 1

            if self.is_chief and self.display:
                self.env.render()

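            # incoming messages start out as zeros: one message_size-sized slot
            # per sender each agent can receive from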
            curr_comm = [[] for _ in range(self.number_of_agents)]
            for curr_agent in range(self.number_of_agents):
                for from_agent in range(self.amount_of_agents_to_send_message_to):
                    curr_comm[curr_agent].extend([0] * self.message_size)

            for episode_step_count in range(max_episode_length):
                # feedforward pass
                action_distribution, value, message = sess.run([self.local_AC.policy, self.local_AC.value,
                                                                self.local_AC.message],
                                                               feed_dict={self.local_AC.inputs: current_screen,
                                                                          self.local_AC.inputs_comm: curr_comm})
                actions = [np.random.choice(action_indexes, p=act_distribution)
                           for act_distribution in action_distribution]

                # TODO hardcoded comms
                #if self.message_size == 1:
                #    for i in range(self.number_of_agents):
                #        message[i] = list(current_screen[i][0])

                # add Gaussian noise to outgoing messages
                if self.comm_gaussian_noise != 0:
                    for index in range(len(message)):
                        message[index] += np.random.normal(0, self.comm_gaussian_noise)

                previous_screen = current_screen
                previous_comm = curr_comm

                # print("OUTPUT MESS", message)

                # Step the environment and observe the outcome
                current_screen, reward, terminal, info = self.env.step(actions)
                this_turns_comm_map = []
                for i in range(self.number_of_agents):
                    # each incoming comm link fails independently with
                    # probability comm_delivery_failure_chance
                    surviving_comms = current_screen[i][1:4]
                    for index in range(len(surviving_comms)):
                        if random.random() < self.comm_delivery_failure_chance:
                            surviving_comms[index] = -1
                    episode_comm_maps[i].append(surviving_comms)
                    this_turns_comm_map.append(surviving_comms)

                    # binarize the neighbour slots: record only whether a car was there, not which car
                    for neighbor_index in range(1, 4):
                        if current_screen[i][neighbor_index] >= 0:
                            current_screen[i][neighbor_index] = 1
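                # route this step's outgoing messages into each agent's incoming
                # slots according to the (possibly lossy) comm map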
                curr_comm = self.output_mess_to_input_mess(message, this_turns_comm_map)

                # jumble received messages with probability comm_jumble_chance
                if self.comm_jumble_chance != 0:
                    for i in range(self.number_of_agents):
                        joint_comm = [0] * self.message_size
                        for index in range(len(curr_comm[i])):
                            joint_comm[index % self.message_size] += curr_comm[i][index]
                        jumble = False
                        for index in range(len(curr_comm[i])):
                            if index % self.message_size == 0:
                                # only jumble messages that got received
                                jumble = curr_comm[i][index] != 0 and random.random() < self.comm_jumble_chance
                            if jumble:
                                curr_comm[i][index] = joint_comm[index % self.message_size]

                if self.is_chief and self.display:
                    self.env.render()
                    sleep(0.2)

                episode_reward += sum(reward) if self.spread_rewards else reward

                for i in range(self.number_of_agents):
                    episode_buffer[i].append([previous_screen[i],
                                              previous_comm[i], actions[i], message[i],
                                              reward[i] if self.spread_rewards else reward,
                                              current_screen[i], curr_comm[i], terminal, value[i]])
                    episode_values[i].append(np.max(value[i]))

                # If the episode hasn't ended, but the experience buffer is full, then we make an update step
                # using that experience rollout.
                if len(episode_buffer[0]) == batch_size and not terminal and \
                                episode_step_count < max_episode_length - 1:
                    v1 = sess.run(self.local_AC.value, feed_dict={self.local_AC.inputs: current_screen,
                                                                  self.local_AC.inputs_comm: curr_comm})
                    for i in range(self.number_of_agents):
                        partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                        v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                            self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC,
                                                                      bootstrap_value=v1[i][0])

                    if self.comm:
                        mgrad_per_sent = self.input_mloss_to_output_mloss(batch_size - 1, mgrad_per_received,
                                                                          episode_comm_maps)

                        # start a new mini batch with only the last sample in the comm map
                        temp_episode_comm_maps = []
                        for i in range(self.number_of_agents):
                            temp_episode_comm_maps.append([episode_comm_maps[i][-1]])
                        episode_comm_maps = temp_episode_comm_maps
                        for i in range(self.number_of_agents):
                            self.apply_comm_gradients(partial_obs[i], partial_mess_rec[i],
                                                      sent_message[i], mgrad_per_sent[i], sess, self.local_AC)

                    # print("Copying global networks to local networks")
                    sess.run(self.update_local_ops)

                    # reset episode buffers. keep last value to be used for t_minus_1 message loss
                    temp_episode_buffer = []
                    for i in range(self.number_of_agents):
                        temp_episode_buffer.append([episode_buffer[i][-1]])
                    episode_buffer = temp_episode_buffer
                    # exit()

                # Measure throughput and increment the total step count
                total_steps += 1
                if total_steps % 2000 == 0:
                    new_clock = time()
                    print(2000.0 / (new_clock - prev_clock), "it/s,   ")
                    prev_clock = new_clock

                # If the environment signals the episode is over, break out of the step loop
                if terminal:
                    break

            # print("0ver ",episode_step_count,episode_reward)
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_step_count)
            self.episode_collisions.append(info["collisions"])
            self.episode_stalls.append(info["stalls"])
            self.episode_mean_values.append(np.mean(episode_values))

            # Update the network using the experience buffer at the end of the episode.
            for i in range(self.number_of_agents):
                partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                    self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC)

            if self.comm and len(mgrad_per_received[0]) != 0:
                mgrad_per_sent = self.input_mloss_to_output_mloss(len(mgrad_per_received[0]), mgrad_per_received,
                                                                  episode_comm_maps)

                for i in range(self.number_of_agents):
                    self.apply_comm_gradients(partial_obs[i], partial_mess_rec[i],
                                              sent_message[i], mgrad_per_sent[i], sess, self.local_AC)

            # print("Copying global networks to local networks")
            sess.run(self.update_local_ops)

            # Periodically save model parameters and summary statistics.
            if episode_count % 5 == 0:

                # Save statistics for TensorBoard
                mean_length = np.mean(self.episode_lengths[-5:])
                mean_collisions = np.mean(self.episode_collisions[-5:])
                mean_stalls = np.mean(self.episode_stalls[-5:])
                mean_reward = np.mean(self.episode_rewards[-5:])
                mean_value = np.mean(self.episode_mean_values[-5:])

                """if self.is_chief and self.message_size == 1 and episode_count % 10 == 0:
                    neighbor0 = 0 if random.random() > 0.5 else 1
                    neighbor1 = 0 if random.random() > 0.5 else 1
                    # neighbor2 = 0 if random.random() > 0.5 else 1
                    vals = sess.run(self.local_AC.message,
                                    feed_dict={self.local_AC.inputs: [[-1, neighbor0, neighbor1, 0],
                                                                      [1, neighbor0, neighbor1, 0],
                                                                      [-1, neighbor0, neighbor1, 1],
                                                                      [1, neighbor0, neighbor1, 1]]})
                    comms_evol[0].append(vals[0])
                    comms_evol[1].append(vals[1])
                    comms_evol[2].append(vals[2])
                    comms_evol[3].append(vals[3])
                    if self.display:
                        mpl.clf()
                        mpl.plot(comms_evol[0], color="green")
                        mpl.plot(comms_evol[1], color="red")
                        mpl.plot(comms_evol[2], color="green")
                        mpl.plot(comms_evol[3], color="red")
                        # mpl.show()
                        mpl.pause(0.001)"""
                if self.is_chief and episode_count % 10 == 0:
                    # [vals, flats] = sess.run([self.local_AC.policy,self.local_AC.flattened_inputs_with_comm],
                    #                feed_dict={self.local_AC.inputs: [[-1, -1, -1, 1],[-1, -1, -1, 1],[1, -1, -1, 1],[1, -1, -1, 1]],
                    #                           self.local_AC.inputs_comm: [[0, 0, -1],[0, 0, 1],[0, 0, -1],[0, 0, 1]]})
                    # print(flats, vals)
                    print("length", mean_length, "reward", mean_reward,
                          "collisions", mean_collisions, "stalls", mean_stalls)

                # Save current model
                if self.is_chief and saver is not None and episode_count % 500 == 0:
                    saver.save(sess, self.model_path + '/model-' + str(episode_count) + '.cptk')
                    print("Saved Model")

                summary = tf.Summary()
                #summary.value.add(tag='Perf/Collisions', simple_value=float(mean_collisions))  # avg episode length
                summary.value.add(tag='Perf/Length', simple_value=float(mean_length))  # avg episode length
                summary.value.add(tag='Perf/Reward', simple_value=float(mean_reward))  # avg reward
                summary.value.add(tag='Perf/Value', simple_value=float(mean_value))  # avg episode value
                summary.value.add(tag='Losses/Value Loss', simple_value=float(np.mean(v_l)))  # value_loss
                summary.value.add(tag='Losses/Policy Loss', simple_value=float(np.mean(p_l)))  # policy_loss
                summary.value.add(tag='Losses/Entropy', simple_value=float(np.mean(e_l)))  # entropy
                summary.value.add(tag='Losses/Grad Norm', simple_value=float(np.mean(g_n)))  # grad_norms
                summary.value.add(tag='Losses/Var Norm', simple_value=float(np.mean(v_n)))  # var_norms
                self.summary_writer.add_summary(summary, episode_count)
                self.summary_writer.flush()

            # Update episode count
            if self.is_chief:
                episode_count = sess.run(self.increment)
                if episode_count % 50 == 0:
                    print("Global episodes @", episode_count)

                if max_episodes is not None and episode_count > max_episodes:
                    coord.request_stop()
            else:
                episode_count = sess.run(self.global_episodes)

        self.env.close()

        """if self.is_chief:
Example #2
    def work(self,
             max_episode_length,
             gamma,
             sess,
             coord=None,
             saver=None,
             max_episodes=None,
             batch_size=25):
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        print("Starting worker " + str(self.number))
        # with sess.as_default(), sess.graph.as_default():
        prev_clock = time()
        if coord is None:
            coord = sess

        already_calculated_actions = False
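        # when a mid-episode batch update runs, the policy/message pass for the
        # bootstrap state is computed there and reused on the next iteration;
        # this flag marks whether that reuse should happen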
        stats1, stats2 = [], []
        while not coord.should_stop():
            sess.run(self.update_local_ops)

            episode_buffer = [[] for _ in range(self.number_of_agents)]
            episode_comm_maps = [[] for _ in range(self.number_of_agents)]
            episode_values = [[] for _ in range(self.number_of_agents)]
            episode_reward = 0
            episode_step_count = 0

            v_l, p_l, e_l, g_n, v_n = get_empty_loss_arrays(
                self.number_of_agents)
            partial_obs = [None for _ in range(self.number_of_agents)]
            partial_mess_rec = [None for _ in range(self.number_of_agents)]
            sent_message = [None for _ in range(self.number_of_agents)]
            mgrad_per_received = [None for _ in range(self.number_of_agents)]

            # start a new episode
            current_screen = self.env.reset()
            # print(current_screen)
            for i in range(self.number_of_agents):
                if self.s_size == 6:
                    current_screen[i] = np.hstack([
                        current_screen[i][0:4],
                        current_screen[i][4 + i * 2:6 + i * 2]
                    ])
                else:
                    current_screen[i] = current_screen[i][0:self.s_size]

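            # centralized-critic input: every agent's observation concatenated,
            # replicated once per agent (optionally extended with the other
            # agents' actions and/or the received comms below)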
            current_screen_central = np.hstack(current_screen)
            arrayed_current_screen_central = [
                current_screen_central for _ in range(self.number_of_agents)
            ]
            for i in range(self.number_of_agents):
                comm_map = list(range(self.number_of_agents))
                comm_map.remove(i)
                episode_comm_maps[i].append(comm_map)

            if self.is_chief and self.display:
                self.env.render()

            curr_comm = [[] for _ in range(self.number_of_agents)]
            for curr_agent in range(self.number_of_agents):
                for from_agent in range(self.number_of_agents - 1):
                    curr_comm[curr_agent].extend([0] * self.message_size)

            stats1_temp = 0

            for episode_step_count in range(max_episode_length):
                # feedforward pass
                if not already_calculated_actions:
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(self.action_indexes,
                                         p=act_distribution)
                        for act_distribution in action_distribution
                    ]
                    actions_one_hot = [
                        self.actions_one_hot[act] for act in actions
                    ]

                    # add Gaussian noise to outgoing messages
                    if self.comm_gaussian_noise != 0:
                        for index in range(len(message)):
                            message[index] += np.random.normal(
                                0, self.comm_gaussian_noise)

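                    # the centralized critic can additionally condition on the
                    # other agents' one-hot actions (critic_action) and on the
                    # received messages (critic_comm)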
                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[
                                agent] = arrayed_current_screen_central[
                                    agent] + actions_one_hot
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[
                                agent] = arrayed_current_screen_central[
                                    agent] + curr_comm[agent]
                already_calculated_actions = False
                value = sess.run(self.local_AC.value,
                                 feed_dict={
                                     self.local_AC.inputs_central:
                                     arrayed_current_screen_central
                                 })

                previous_screen = current_screen
                arrayed_previous_screen_central = arrayed_current_screen_central
                previous_comm = curr_comm

                # Step the environment and observe the outcome
                current_screen, reward, terminal, info = self.env.step(
                    actions_one_hot)
                for info_agent in info['n']:
                    stats1_temp += (info_agent[1] - 1) / 2

                for i in range(self.number_of_agents):
                    current_screen[i] = current_screen[i][0:self.s_size]
                terminal = (np.sum(reward) > -0.25 * self.number_of_agents
                            or stats1_temp > 5)
                current_screen_central = np.hstack(current_screen)
                arrayed_current_screen_central = [
                    current_screen_central
                    for _ in range(self.number_of_agents)
                ]

                this_turns_comm_map = []
                for i in range(self.number_of_agents):
                    # each incoming comm link fails independently with
                    # probability comm_delivery_failure_chance
                    surviving_comms = list(range(self.number_of_agents))
                    surviving_comms.remove(i)
                    for index in range(len(surviving_comms)):
                        if random.random() < self.comm_delivery_failure_chance:
                            surviving_comms[index] = -1
                    episode_comm_maps[i].append(surviving_comms)
                    this_turns_comm_map.append(surviving_comms)
                curr_comm = self.output_mess_to_input_mess(
                    message, this_turns_comm_map)

                if self.is_chief and self.display:
                    self.env.render()
                    sleep(0.2)
                episode_reward += sum(
                    reward) if self.spread_rewards else reward

                # jumble received messages with probability comm_jumble_chance
                if self.comm_jumble_chance != 0:
                    for i in range(self.number_of_agents):
                        joint_comm = [0] * self.message_size
                        for index in range(len(curr_comm[i])):
                            joint_comm[
                                index %
                                self.message_size] += curr_comm[i][index]
                        jumble = False
                        for index in range(len(curr_comm[i])):
                            if index % self.message_size == 0:
                                # only jumble messages that got received
                                jumble = curr_comm[i][
                                    index] != 0 and random.random(
                                    ) < self.comm_jumble_chance
                            if jumble:
                                curr_comm[i][index] = joint_comm[
                                    index % self.message_size]

                for i in range(self.number_of_agents):
                    episode_buffer[i].append([
                        previous_screen[i], arrayed_previous_screen_central[i],
                        previous_comm[i], actions[i], message[i],
                        reward[i] if self.spread_rewards else reward,
                        current_screen[i], curr_comm[i], terminal, value[i]
                    ])
                    episode_values[i].append(np.max(value[i]))

                # If the episode hasn't ended, but the experience buffer is full, then we make an update step
                # using that experience rollout.
                if len(episode_buffer[0]) == batch_size and not terminal and \
                                episode_step_count < max_episode_length - 1:
                    # feedforward pass
                    action_distribution, message = sess.run(
                        [self.local_AC.policy, self.local_AC.message],
                        feed_dict={
                            self.local_AC.inputs: current_screen,
                            self.local_AC.inputs_comm: curr_comm
                        })
                    actions = [
                        np.random.choice(self.action_indexes,
                                         p=act_distribution)
                        for act_distribution in action_distribution
                    ]
                    actions_one_hot = [
                        self.actions_one_hot[act] for act in actions
                    ]

                    if self.critic_action:
                        for agent in range(self.number_of_agents):
                            actions_one_hot = one_hot_encoding(
                                actions[0:agent] + actions[agent + 1:],
                                self.number_of_actions)
                            arrayed_current_screen_central[
                                agent] = arrayed_current_screen_central[
                                    agent] + actions_one_hot
                    if self.critic_comm:
                        for agent in range(self.number_of_agents):
                            arrayed_current_screen_central[
                                agent] = arrayed_current_screen_central[
                                    agent] + curr_comm[agent]
                    already_calculated_actions = True
                    v1 = sess.run(self.local_AC.value,
                                  feed_dict={
                                      self.local_AC.inputs_central:
                                      arrayed_current_screen_central
                                  })

                    for i in range(self.number_of_agents):
                        partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                        v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                            self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC,
                                                                      bootstrap_value=v1[i][0])

                    if self.comm:
                        mgrad_per_sent = self.input_mloss_to_output_mloss(
                            batch_size - 1, mgrad_per_received,
                            episode_comm_maps)

                        # start a new mini batch with only the last sample in the comm map
                        temp_episode_comm_maps = []
                        for i in range(self.number_of_agents):
                            temp_episode_comm_maps.append(
                                [episode_comm_maps[i][-1]])
                        episode_comm_maps = temp_episode_comm_maps
                        for i in range(self.number_of_agents):
                            self.apply_comm_gradients(partial_obs[i],
                                                      partial_mess_rec[i],
                                                      sent_message[i],
                                                      mgrad_per_sent[i], sess,
                                                      self.local_AC)

                    # print("Copying global networks to local networks")
                    sess.run(self.update_local_ops)

                    # reset episode buffers. keep last value to be used for t_minus_1 message loss
                    temp_episode_buffer = []
                    for i in range(self.number_of_agents):
                        temp_episode_buffer.append([episode_buffer[i][-1]])
                    episode_buffer = temp_episode_buffer

                # Measure throughput and increment the total step count
                total_steps += 1
                if total_steps % 2000 == 0:
                    new_clock = time()
                    print(2000.0 / (new_clock - prev_clock), "it/s,   ")
                    prev_clock = new_clock

                # If the environment signals the episode is over, break out of the step loop
                if terminal:
                    break

            stats1.append(stats1_temp)
            stats2.append(-reward[0] / self.number_of_agents)

            # print("0ver ",episode_step_count,episode_reward)
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_step_count)
            self.episode_mean_values.append(np.mean(episode_values))

            # Update the network using the experience buffer at the end of the episode.
            for i in range(self.number_of_agents):
                partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                    self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC)

            if self.comm and len(mgrad_per_received[0]) != 0:
                mgrad_per_sent = self.input_mloss_to_output_mloss(
                    len(mgrad_per_received[0]), mgrad_per_received,
                    episode_comm_maps)

                for i in range(self.number_of_agents):
                    self.apply_comm_gradients(partial_obs[i],
                                              partial_mess_rec[i],
                                              sent_message[i],
                                              mgrad_per_sent[i], sess,
                                              self.local_AC)

            # print("Copying global networks to local networks")
            sess.run(self.update_local_ops)

            # Periodically save model parameters and summary statistics.
            if episode_count % 5 == 0:

                # Save statistics for TensorBoard
                mean_length = np.mean(self.episode_lengths[-5:])
                mean_reward = np.mean(self.episode_rewards[-5:])
                mean_value = np.mean(self.episode_mean_values[-5:])

                if self.is_chief and episode_count % 10 == 0:
                    print("length", mean_length, "reward", np.max(reward))

                # Save current model
                if self.is_chief and saver is not None and episode_count % 500 == 0:
                    saver.save(
                        sess, self.model_path + '/model-' +
                        str(episode_count) + '.cptk')
                    print("Saved Model")

                summary = tf.Summary()
                summary.value.add(
                    tag='Perf/Length',
                    simple_value=float(mean_length))  # avg episode length
                summary.value.add(
                    tag='Perf/Reward',
                    simple_value=float(mean_reward))  # avg reward
                summary.value.add(
                    tag='Perf/Value', simple_value=float(
                        mean_value))  # avg episode value
                summary.value.add(tag='Losses/Value Loss',
                                  simple_value=float(
                                      np.mean(v_l)))  # value_loss
                summary.value.add(tag='Losses/Policy Loss',
                                  simple_value=float(
                                      np.mean(p_l)))  # policy_loss
                summary.value.add(tag='Losses/Entropy',
                                  simple_value=float(np.mean(e_l)))  # entropy
                summary.value.add(tag='Losses/Grad Norm',
                                  simple_value=float(
                                      np.mean(g_n)))  # grad_norms
                summary.value.add(tag='Losses/Var Norm',
                                  simple_value=float(
                                      np.mean(v_n)))  # var_norms
                self.summary_writer.add_summary(summary, episode_count)
                self.summary_writer.flush()

            # Update episode count
            if self.is_chief:
                episode_count = sess.run(self.increment)
                if episode_count % 50 == 0:
                    print("Global episodes @", episode_count)

                if max_episodes is not None and episode_count > max_episodes:
                    coord.request_stop()
            else:
                episode_count = sess.run(self.global_episodes)

        self.env.close()
        print(stats1)
        print(stats2)
        print(np.mean(stats1), np.mean(stats2))
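
output_mess_to_input_mess itself is not shown in any of the examples. Based on how the comm maps are built (sender indices, with -1 marking a dropped or absent link) and on the zero-initialised incoming buffers, a plausible sketch of its contract is given below; route_messages is an illustrative name, and this is a guess for orientation, not the original implementation.

def route_messages(messages, comm_maps, message_size):
    # messages[j]  : flat list of message_size floats emitted by agent j
    # comm_maps[i] : sender indices for agent i, with -1 where the link failed
    incoming = []
    for senders in comm_maps:
        buf = []
        for sender in senders:
            if sender < 0:
                buf.extend([0.0] * message_size)  # dropped or absent sender
            else:
                # int() because Example #1 reuses neighbour ids taken from the
                # observation, which may be stored as floats
                buf.extend(messages[int(sender)][:message_size])
        incoming.append(buf)
    return incoming
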
Example #3
    def work(self,
             max_episode_length,
             gamma,
             sess,
             coord=None,
             saver=None,
             max_episodes=None,
             batch_size=25):
        episode_count = sess.run(self.global_episodes)
        total_steps = 0
        print("Starting worker " + str(self.number))
        # with sess.as_default(), sess.graph.as_default():
        prev_clock = time()
        if coord is None:
            coord = sess

        already_calculated_actions = False
        while not coord.should_stop():
            sess.run(self.update_local_ops)

            episode_buffer = [[] for _ in range(self.number_of_agents)]
            episode_comm_maps = [[] for _ in range(self.number_of_agents)]
            episode_values = [[] for _ in range(self.number_of_agents)]
            episode_reward = 0
            episode_step_count = 0
            action_indexes = list(range(self.env.max_actions))

            v_l, p_l, e_l, g_n, v_n = get_empty_loss_arrays(
                self.number_of_agents)
            partial_obs = [None for _ in range(self.number_of_agents)]
            partial_mess_rec = [None for _ in range(self.number_of_agents)]
            sent_message = [None for _ in range(self.number_of_agents)]
            mgrad_per_received = [None for _ in range(self.number_of_agents)]

            # start a new episode
            current_screen, info = self.env.reset()
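            # this variant hard-codes three agents in the per-agent placeholders
            # (and in the batch-full check further down); agents may be absent on
            # a given step, and the centralized state comes directly from the
            # environment via info["state_central"]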
            previous_screen = [None, None, None]
            previous_actions = [-1, -1, -1]
            arrayed_current_screen_central = [
                info["state_central"] for _ in range(self.number_of_agents)
            ]
            for i in range(self.number_of_agents):
                comm_map = list(range(self.number_of_agents))
                comm_map.remove(i)
                episode_comm_maps[i].append(comm_map)

            curr_comm = [[] for _ in range(self.number_of_agents)]
            for curr_agent in range(self.number_of_agents):
                for from_agent in range(
                        self.amount_of_agents_to_send_message_to):
                    curr_comm[curr_agent].extend([0] * self.message_size)

            for episode_step_count in range(max_episode_length):
                # print(current_screen, arrayed_current_screen_central)
                # feedforward pass
                # print(current_screen)
                # print(curr_comm)
                actions = [-1, -1, -1]
                message = [[], [], []]
                action_distribution = [None, None, None]
                for i in range(self.number_of_agents):
                    if current_screen[i] is not None:
                        [action_distribution[i]], [message[i]] = \
                            sess.run([self.local_AC.policy, self.local_AC.message],
                                     feed_dict={self.local_AC.inputs: [current_screen[i]],
                                                self.local_AC.inputs_comm: [curr_comm[i]]})
                        """if np.isnan(action_distribution[i]).any():
                            print(self.name, "Found NaN! Input:", current_screen[i], "Output:", action_distribution[i])
                            actions[i] = 0
                            exit()
                        else:"""
                        actions[i] = np.random.choice(action_indexes,
                                                      p=action_distribution[i])
                        """# TODO
                        if current_screen[i][12] < current_screen[i][13] and \
                                        current_screen[i][12] < current_screen[i][14]:
                            actions[i] = 1
                            #print("kicker:", i)
                        else:
                            actions[i] = 4"""
                """if np.isnan(arrayed_current_screen_central).any():
                    print(self.name, "Found NaN, arrayed_current_screen_central:")
                    print(arrayed_current_screen_central)"""
                value = sess.run(self.local_AC.value,
                                 feed_dict={
                                     self.local_AC.inputs_central:
                                     arrayed_current_screen_central
                                 })
                """if np.isnan(value).any():
                    print(self.name, "Found NaN, value:")
                    print(arrayed_current_screen_central)
                    print(value)
                    for v in value:
                        if np.isnan(v[0]):
                            v[0] = 0"""

                for i in range(self.number_of_agents):
                    if current_screen[i] is not None:
                        previous_screen[i] = current_screen[i]
                        previous_actions[i] = actions[i]
                previous_arrayed_screen_central = arrayed_current_screen_central
                previous_comm = curr_comm

                # Step the environment and observe the outcome
                current_screen, reward, terminal, info = self.env.step(actions)

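                # a length-1 observation means the agent produced no usable
                # observation this step; set it to None so it is skipped when
                # storing experience below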
                for i in range(self.number_of_agents):
                    if len(current_screen[i]) == 1:
                        current_screen[i] = None

                arrayed_current_screen_central = [
                    info["state_central"] for _ in range(self.number_of_agents)
                ]
                this_turns_comm_map = []
                for i in range(self.number_of_agents):
                    surviving_comms = list(range(self.number_of_agents))
                    surviving_comms.remove(i)
                    episode_comm_maps[i].append(surviving_comms)
                    this_turns_comm_map.append(surviving_comms)
                curr_comm = self.output_mess_to_input_mess(
                    message, this_turns_comm_map)

                episode_reward += sum(
                    reward) if self.spread_rewards else reward

                for i in range(self.number_of_agents):
                    if current_screen[i] is not None:
                        #print( previous_actions[i])
                        episode_buffer[i].append([
                            previous_screen[i],
                            previous_arrayed_screen_central[i],
                            previous_comm[i], previous_actions[i], message[i],
                            reward[i] if self.spread_rewards else reward,
                            current_screen[i], curr_comm[i], terminal, value[i]
                        ])
                        episode_values[i].append(np.max(value[i]))

                # If the episode hasn't ended, but the experience buffer is full, then we make an update step
                # using that experience rollout.
                if (len(episode_buffer[0]) == batch_size or len(episode_buffer[1]) == batch_size or
                            len(episode_buffer[2]) == batch_size) and not terminal and \
                                episode_step_count < max_episode_length - 1:
                    """if np.isnan(arrayed_current_screen_central).any():
                        print(self.name, "Found NaN, arrayed_current_screen_central:")
                        print(arrayed_current_screen_central)"""
                    v1 = sess.run(self.local_AC.value,
                                  feed_dict={
                                      self.local_AC.inputs_central:
                                      arrayed_current_screen_central
                                  })
                    """if np.isnan(v1).any():
                        print(self.name, "Found NaN, v1:")
                        print(arrayed_current_screen_central)
                        print(v1)
                        for v in v1:
                            if np.isnan(v[0]):
                                v[0] = 0"""
                    for i in range(self.number_of_agents):
                        if len(episode_buffer[i]) == batch_size:
                            # print("optimizing",i,"with",len(episode_buffer[i]),"samples")
                            partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                            v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                                self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC,
                                                                          bootstrap_value=v1[i][0])

                    if self.comm:
                        mgrad_per_sent = self.input_mloss_to_output_mloss(
                            batch_size - 1, mgrad_per_received,
                            episode_comm_maps)

                        # start a new mini batch with only the last sample in the comm map
                        temp_episode_comm_maps = []
                        for i in range(self.number_of_agents):
                            temp_episode_comm_maps.append(
                                [episode_comm_maps[i][-1]])
                        episode_comm_maps = temp_episode_comm_maps
                        for i in range(self.number_of_agents):
                            self.apply_comm_gradients(partial_obs[i],
                                                      partial_mess_rec[i],
                                                      sent_message[i],
                                                      mgrad_per_sent[i], sess,
                                                      self.local_AC)

                    # print("Copying global networks to local networks")
                    sess.run(self.update_local_ops)

                    # reset episode buffers. keep last value to be used for t_minus_1 message loss
                    temp_episode_buffer = []
                    for i in range(self.number_of_agents):
                        if len(episode_buffer[i]) == batch_size:
                            temp_episode_buffer.append([episode_buffer[i][-1]])
                        else:
                            temp_episode_buffer.append(episode_buffer[i])
                    episode_buffer = temp_episode_buffer

                # Measure throughput and increment the total step count
                total_steps += 1
                if total_steps % 2000 == 0:
                    new_clock = time()
                    print(2000.0 / (new_clock - prev_clock), "it/s,   ")
                    prev_clock = new_clock

                # If the environment signals the episode is over, break out of the step loop
                if terminal:
                    print(self.name, "terminal @", episode_step_count)
                    break

            # print("0ver ",episode_step_count,episode_reward)
            self.episode_rewards.append(episode_reward)
            self.episode_lengths.append(episode_step_count)
            self.episode_mean_values.append(
                np.mean([np.mean(x) for x in episode_values]))

            # Update the network using the experience buffer at the end of the episode.
            """print(self.name, "final opti", len(episode_buffer[0]), len(episode_buffer[1]), len(episode_buffer[2]))"""
            for i in range(self.number_of_agents):
                partial_obs[i], partial_mess_rec[i], sent_message[i], mgrad_per_received[i], \
                v_l[i], p_l[i], e_l[i], g_n[i], v_n[i] = \
                    self.train_weights_and_get_comm_gradients(episode_buffer[i], sess, gamma, self.local_AC)

            if self.comm and episode_step_count != 0:
                mgrad_per_sent = self.input_mloss_to_output_mloss(
                    len(mgrad_per_received[0]), mgrad_per_received,
                    episode_comm_maps)

                for i in range(self.number_of_agents):
                    self.apply_comm_gradients(partial_obs[i],
                                              partial_mess_rec[i],
                                              sent_message[i],
                                              mgrad_per_sent[i], sess,
                                              self.local_AC)

            # print("Copying global networks to local networks")
            sess.run(self.update_local_ops)

            # Periodically save model parameters and summary statistics.
            if episode_count % 5 == 0:

                # Save statistics for TensorBoard
                mean_length = np.mean(self.episode_lengths[-5:])
                mean_reward = np.mean(self.episode_rewards[-5:])
                mean_value = np.mean(self.episode_mean_values[-5:])

                if self.is_chief and episode_count % 10 == 0:
                    print("length", mean_length, "reward", mean_reward)

                # Save current model
                if self.is_chief and saver is not None and episode_count % 500 == 0:
                    saver.save(
                        sess, self.model_path + '/model-' +
                        str(episode_count) + '.cptk')
                    print("Saved Model")

                summary = tf.Summary()
                summary.value.add(
                    tag='Perf/Length',
                    simple_value=float(mean_length))  # avg episode length
                summary.value.add(
                    tag='Perf/Reward',
                    simple_value=float(mean_reward))  # avg reward
                summary.value.add(
                    tag='Perf/Value', simple_value=float(
                        mean_value))  # avg episode value
                summary.value.add(tag='Losses/Value Loss',
                                  simple_value=float(
                                      np.mean(v_l)))  # value_loss
                summary.value.add(tag='Losses/Policy Loss',
                                  simple_value=float(
                                      np.mean(p_l)))  # policy_loss
                summary.value.add(tag='Losses/Entropy',
                                  simple_value=float(np.mean(e_l)))  # entropy
                summary.value.add(tag='Losses/Grad Norm',
                                  simple_value=float(
                                      np.mean(g_n)))  # grad_norms
                summary.value.add(tag='Losses/Var Norm',
                                  simple_value=float(
                                      np.mean(v_n)))  # var_norms
                self.summary_writer.add_summary(summary, episode_count)
                self.summary_writer.flush()

            # Update episode count
            if self.is_chief:
                episode_count = sess.run(self.increment)
                if episode_count % 50 == 0:
                    print("Global episodes @", episode_count)

                if max_episodes is not None and episode_count > max_episodes:
                    coord.request_stop()
            else:
                episode_count = sess.run(self.global_episodes)

        self.env.close()
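
get_empty_loss_arrays is used identically in all three examples but not shown. Since each of the five returned containers is indexed per agent and later reduced with np.mean, a minimal sketch consistent with that usage could look like the following (illustrative, not the original helper):

import numpy as np

def get_empty_loss_arrays(number_of_agents):
    # one slot per agent for: value loss, policy loss, entropy,
    # gradient norm, variable norm
    v_l = np.zeros(number_of_agents)
    p_l = np.zeros(number_of_agents)
    e_l = np.zeros(number_of_agents)
    g_n = np.zeros(number_of_agents)
    v_n = np.zeros(number_of_agents)
    return v_l, p_l, e_l, g_n, v_n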