Example #1
    def create_testing_environment(self, assay):
        """
        Creates the testing environment as specified by apparatus mode and given assays.
        :return:
        """
        if assay["stimulus paradigm"] == "Projection":
            self.simulation = ControlledStimulusEnvironment(
                self.environment_params,
                assay["stimuli"],
                self.realistic_bouts,
                tethered=assay["Tethered"],
                set_positions=assay["set positions"],
                random=assay["random positions"],
                moving=assay["moving"],
                reset_each_step=assay["reset"],
                reset_interval=assay["reset interval"],
                background=assay["background"])
        elif assay["stimulus paradigm"] == "Naturalistic":
            self.simulation = NaturalisticEnvironment(
                self.environment_params,
                self.realistic_bouts,
                collisions=assay["collisions"])
        else:
            self.simulation = NaturalisticEnvironment(self.environment_params,
                                                      self.realistic_bouts)
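
A minimal sketch of the assay dictionary this method expects, inferred only from the keys read above; every value is a hypothetical placeholder.

# Hypothetical assay specification; only the keys accessed by
# create_testing_environment are shown, and all values are illustrative.
example_assay = {
    "stimulus paradigm": "Projection",   # or "Naturalistic"
    "stimuli": {},                       # stimulus definitions for the projection paradigm
    "Tethered": True,
    "set positions": False,
    "random positions": False,
    "moving": False,
    "reset": False,
    "reset interval": 100,
    "background": None,
    "collisions": True,                  # only read in the "Naturalistic" branch
}
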
Example #2
    def switch_configuration(self, next_point):
        self.configuration_index = int(next_point)
        self.switched_configuration = True
        print(
            f"{self.trial_id}: Changing configuration to configuration {self.configuration_index}"
        )
        self.params, self.env = self.load_configuration_files()
        self.simulation = NaturalisticEnvironment(self.env,
                                                  self.realistic_bouts)
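
switch_configuration relies on load_configuration_files (shown in full in Example #9), which reads a numbered pair of JSON files per configuration from the scaffold directory. A sketch of that loading convention, with a hypothetical scaffold name and index:

import json

# Hypothetical scaffold name and configuration index; mirrors
# load_configuration_files in Example #9.
scaffold_name = "example_scaffold"
configuration_index = 2
configuration_location = f"./Configurations/{scaffold_name}/{configuration_index}"
with open(f"{configuration_location}_learning.json", "r") as f:
    params = json.load(f)
with open(f"{configuration_location}_env.json", "r") as f:
    env = json.load(f)
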
Example #3
def produce_values_for_a_m_value(m):
    s = NaturalisticEnvironment(env, draw_screen=False, fish_mass=m)
    data_a = produce_calibration_curve_data(s)

    for i in range(100):
        s = NaturalisticEnvironment(env, draw_screen=False, fish_mass=m)
        data_b = produce_calibration_curve_data(s)
        data_a = produce_weighted_average(data_a, data_b, i)

    return data_a
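
A hedged usage sketch: sweeping the calibration helper over a few candidate fish masses. It assumes env, produce_calibration_curve_data and produce_weighted_average are already defined in the module, as above; the mass values are illustrative.

# Hypothetical mass sweep; the values and the result handling are placeholders.
candidate_masses = [100, 200, 400]
calibration_curves = {m: produce_values_for_a_m_value(m) for m in candidate_masses}
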
Example #4
    def __init__(self, model_name, trial_number, assay_config_name,
                 learning_params, environment_params, total_steps,
                 episode_number, assays, realistic_bouts, memory_fraction,
                 using_gpu, set_random_seed):
        """
        Runs a set of assays provided by the run configuration.
        """

        # Set random seed
        if set_random_seed:
            np.random.seed(404)

        # Names and Directories
        self.model_id = f"{model_name}-{trial_number}"
        self.model_location = f"./Training-Output/{self.model_id}"
        self.data_save_location = f"./Assay-Output/{self.model_id}"

        # Configurations
        self.assay_configuration_id = assay_config_name
        self.learning_params = learning_params
        self.environment_params = environment_params
        self.assays = assays

        # Basic Parameters
        self.using_gpu = using_gpu
        self.realistic_bouts = realistic_bouts
        self.memory_fraction = memory_fraction

        # Network Parameters
        self.saver = None
        self.network = None
        self.init = None
        self.sess = None

        # Simulation
        self.simulation = NaturalisticEnvironment(self.environment_params,
                                                  self.realistic_bouts)
        self.step_number = 0

        # Data
        self.metadata = {
            "Total Episodes": episode_number,
            "Total Steps": total_steps,
        }
        self.frame_buffer = []
        self.assay_output_data_format = None
        self.assay_output_data = []
        self.output_data = {}
        self.episode_summary_data = None

        # Hacky fix for h5py problem:
        self.last_position_dim = self.environment_params["prey_num"]
        self.stimuli_data = []
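
A hedged construction sketch based only on the signature above. The configuration objects stand in for data normally loaded from the project's JSON configuration files, and every literal value is a placeholder.

# Hypothetical instantiation; learning_params, environment_params and assays
# would normally be loaded from configuration files. "prey_num" is included
# because __init__ reads it directly.
learning_params = {}
environment_params = {"prey_num": 10}
assays = []  # list of assay dictionaries (see the sketch after Example #1)

service = AssayService(model_name="example_model",
                       trial_number=1,
                       assay_config_name="example_assays",
                       learning_params=learning_params,
                       environment_params=environment_params,
                       total_steps=0,
                       episode_number=0,
                       assays=assays,
                       realistic_bouts=True,
                       memory_fraction=0.5,
                       using_gpu=False,
                       set_random_seed=True)
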
Example #5
def load_network_variables(model_name, conf_name):
    learning, env = load_configuration_files(f"{conf_name}")
    simulation = NaturalisticEnvironment(env, False)
    model_location = f"../../Training-Output/{model_name}"

    with tf.Session() as sess:
        cell = tf.nn.rnn_cell.LSTMCell(num_units=learning["rnn_dim"], state_is_tuple=True)
        internal_states = sum([1 for x in [env['hunger'], env['stress']] if x is True]) + 1
        network = QNetwork(simulation=simulation,
                           rnn_dim=learning["rnn_dim"],
                           rnn_cell=cell,
                           my_scope='main',
                           internal_states=internal_states,
                           num_actions=learning["num_actions"],
                           learning_rate=learning["learning_rate"])
        saver = tf.train.Saver(max_to_keep=5)
        init = tf.global_variables_initializer()
        checkpoint = tf.train.get_checkpoint_state(model_location)
        saver.restore(sess, checkpoint.model_checkpoint_path)
        vars = tf.trainable_variables()
        vals = sess.run(vars)
        sorted_vars = {}
        for var, val in zip(vars, vals):
            sorted_vars[str(var.name)] = val
        return sorted_vars
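
A hedged call sketch; the model and configuration names are hypothetical and must match an existing folder under ../../Training-Output and the configuration files that load_configuration_files expects.

# Hypothetical names; returns a dict mapping trainable-variable names to values.
variables = load_network_variables("example_model-1", "example_config")
print(sorted(variables.keys()))
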
Example #6
class AssayService:
    def __init__(self, model_name, trial_number, assay_config_name,
                 learning_params, environment_params, total_steps,
                 episode_number, assays, realistic_bouts, memory_fraction,
                 using_gpu, set_random_seed):
        """
        Runs a set of assays provided by the run configuration.
        """

        # Set random seed
        if set_random_seed:
            np.random.seed(404)

        # Names and Directories
        self.model_id = f"{model_name}-{trial_number}"
        self.model_location = f"./Training-Output/{self.model_id}"
        self.data_save_location = f"./Assay-Output/{self.model_id}"

        # Configurations
        self.assay_configuration_id = assay_config_name
        self.learning_params = learning_params
        self.environment_params = environment_params
        self.assays = assays

        # Basic Parameters
        self.using_gpu = using_gpu
        self.realistic_bouts = realistic_bouts
        self.memory_fraction = memory_fraction

        # Network Parameters
        self.saver = None
        self.network = None
        self.init = None
        self.sess = None

        # Simulation
        self.simulation = NaturalisticEnvironment(self.environment_params,
                                                  self.realistic_bouts)
        self.step_number = 0

        # Data
        self.metadata = {
            "Total Episodes": episode_number,
            "Total Steps": total_steps,
        }
        self.frame_buffer = []
        self.assay_output_data_format = None
        self.assay_output_data = []
        self.output_data = {}
        self.episode_summary_data = None

        # Hacky fix for h5py problem:
        self.last_position_dim = self.environment_params["prey_num"]
        self.stimuli_data = []

    def create_network(self):
        internal_states = sum([
            1 for x in [
                self.environment_params['hunger'],
                self.environment_params['stress']
            ] if x is True
        ]) + 1

        cell = tf.nn.rnn_cell.LSTMCell(
            num_units=self.learning_params['rnn_dim'], state_is_tuple=True)
        network = QNetwork(simulation=self.simulation,
                           rnn_dim=self.learning_params['rnn_dim'],
                           rnn_cell=cell,
                           my_scope='main',
                           num_actions=self.learning_params['num_actions'],
                           internal_states=internal_states,
                           learning_rate=self.learning_params['learning_rate'],
                           extra_layer=self.learning_params['extra_rnn'])
        return network

    def create_testing_environment(self, assay):
        """
        Creates the testing environment as specified by apparatus mode and given assays.
        :return:
        """
        if assay["stimulus paradigm"] == "Projection":
            self.simulation = ControlledStimulusEnvironment(
                self.environment_params,
                assay["stimuli"],
                self.realistic_bouts,
                tethered=assay["Tethered"],
                set_positions=assay["set positions"],
                random=assay["random positions"],
                moving=assay["moving"],
                reset_each_step=assay["reset"],
                reset_interval=assay["reset interval"],
                background=assay["background"])
        elif assay["stimulus paradigm"] == "Naturalistic":
            self.simulation = NaturalisticEnvironment(
                self.environment_params,
                self.realistic_bouts,
                collisions=assay["collisions"])
        else:
            self.simulation = NaturalisticEnvironment(self.environment_params,
                                                      self.realistic_bouts)

    def run(self):
        if self.using_gpu:
            options = tf.GPUOptions(
                per_process_gpu_memory_fraction=self.memory_fraction)
        else:
            options = None

        if options:
            with tf.Session(config=tf.ConfigProto(
                    gpu_options=options)) as self.sess:
                self._run()
        else:
            with tf.Session() as self.sess:
                self._run()

    def _run(self):
        self.network = self.create_network()
        self.saver = tf.train.Saver(max_to_keep=5)
        self.init = tf.global_variables_initializer()
        checkpoint = tf.train.get_checkpoint_state(self.model_location)
        self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
        print("Model loaded")
        for assay in self.assays:
            if assay["ablations"]:
                self.ablate_units(assay["ablations"])
            self.create_output_data_storage(assay)
            self.create_testing_environment(assay)
            self.perform_assay(assay)
            if assay["save stimuli"]:
                self.save_stimuli_data(assay)
            # self.save_assay_results(assay)
            self.save_hdf5_data(assay)
        self.save_metadata()
        self.save_episode_data()

    def create_output_data_storage(self, assay):
        self.output_data = {key: [] for key in assay["recordings"]}
        self.output_data["step"] = []

    def ablate_units(self, unit_indexes):
        for unit in unit_indexes:
            if unit < 256:
                output = self.sess.graph.get_tensor_by_name('mainaw:0')
                new_tensor = output.eval()
                new_tensor[unit] = np.array([0 for i in range(10)])
                self.sess.run(tf.assign(output, new_tensor))
            else:
                output = self.sess.graph.get_tensor_by_name('mainvw:0')
                new_tensor = output.eval()
                new_tensor[unit - 256] = np.array([0])
                self.sess.run(tf.assign(output, new_tensor))

    def perform_assay(self, assay):
        self.assay_output_data_format = {
            key: None
            for key in assay["recordings"]
        }

        self.simulation.reset()
        rnn_state = (np.zeros([1, self.network.rnn_dim]),
                     np.zeros([1, self.network.rnn_dim])
                     )  # Reset RNN hidden state
        sa = np.zeros((1, 128))

        o, r, internal_state, d, self.frame_buffer = self.simulation.simulation_step(
            action=3,
            frame_buffer=self.frame_buffer,
            save_frames=True,
            activations=(sa, ))
        a = 0
        self.step_number = 0
        while self.step_number < assay["duration"]:
            if assay["reset"] and self.step_number % assay[
                    "reset interval"] == 0:
                rnn_state = (np.zeros([1, self.network.rnn_dim]),
                             np.zeros([1, self.network.rnn_dim])
                             )  # Reset RNN hidden state
            self.step_number += 1

            o, a, r, internal_state, o1, d, rnn_state = self.step_loop(
                o=o, internal_state=internal_state, a=a, rnn_state=rnn_state)
            o = o1

            if d:
                break

    def step_loop(self, o, internal_state, a, rnn_state):
        chosen_a, updated_rnn_state, rnn2_state, sa, sv, conv1l, conv2l, conv3l, conv4l, conv1r, conv2r, conv3r, conv4r, o2 = \
            self.sess.run(
                [self.network.predict, self.network.rnn_state, self.network.rnn_state2, self.network.streamA, self.network.streamV,
                 self.network.conv1l, self.network.conv2l, self.network.conv3l, self.network.conv4l,
                 self.network.conv1r, self.network.conv2r, self.network.conv3r, self.network.conv4r,
                 [self.network.ref_left_eye, self.network.ref_right_eye],
                 ],
                feed_dict={self.network.observation: o,
                           self.network.internal_state: internal_state,
                           self.network.prev_actions: [a],
                           self.network.trainLength: 1,
                           self.network.state_in: rnn_state,
                           self.network.batch_size: 1,
                           self.network.exp_keep: 1.0})
        chosen_a = chosen_a[0]
        o1, given_reward, internal_state, d, self.frame_buffer = self.simulation.simulation_step(
            action=chosen_a,
            frame_buffer=self.frame_buffer,
            save_frames=True,
            activations=(sa, ))
        fish_angle = self.simulation.fish.body.angle

        if self.simulation.sand_grain_bodies:
            sand_grain_positions = [
                self.simulation.sand_grain_bodies[i].position
                for i, b in enumerate(self.simulation.sand_grain_bodies)
            ]
            sand_grain_positions = [[i[0], i[1]] for i in sand_grain_positions]
        else:
            sand_grain_positions = [[10000, 10000]]

        if self.simulation.prey_bodies:
            # TODO: Note hacky fix which may want to clean up later.
            prey_positions = [
                prey.position for prey in self.simulation.prey_bodies
            ]
            prey_positions = [[i[0], i[1]] for i in prey_positions]
            while True:
                if len(prey_positions) < self.last_position_dim:
                    prey_positions = np.append(prey_positions,
                                               [[10000, 10000]],
                                               axis=0)
                else:
                    break

            self.last_position_dim = len(prey_positions)

        else:
            prey_positions = np.array([[10000, 10000]])

        if self.simulation.predator_body is not None:
            predator_position = self.simulation.predator_body.position
            predator_position = np.array(
                [predator_position[0], predator_position[1]])
        else:
            predator_position = np.array([10000, 10000])

        if self.simulation.vegetation_bodies is not None:
            vegetation_positions = [
                self.simulation.vegetation_bodies[i].position
                for i, b in enumerate(self.simulation.vegetation_bodies)
            ]
            vegetation_positions = [[i[0], i[1]] for i in vegetation_positions]
        else:
            vegetation_positions = [[10000, 10000]]

        if not self.learning_params["extra_rnn"]:
            rnn2_state = [0.0]

        # Saving step data
        possible_data_to_save = self.package_output_data(
            o1,
            o2,
            chosen_a,
            sa,
            updated_rnn_state,
            rnn2_state,
            self.simulation.fish.body.position,
            self.simulation.prey_consumed_this_step,
            self.simulation.predator_body,
            conv1l,
            conv2l,
            conv3l,
            conv4l,
            conv1r,
            conv2r,
            conv3r,
            conv4r,
            prey_positions,
            predator_position,
            sand_grain_positions,
            vegetation_positions,
            fish_angle,
        )
        for key in self.assay_output_data_format:
            self.output_data[key].append(possible_data_to_save[key])
        self.output_data["step"].append(self.step_number)

        return o, chosen_a, given_reward, internal_state, o1, d, updated_rnn_state

    def save_hdf5_data(self, assay):
        if assay["save frames"]:
            make_gif(
                self.frame_buffer,
                f"{self.data_save_location}/{self.assay_configuration_id}-{assay['assay id']}.gif",
                duration=len(self.frame_buffer) *
                self.learning_params['time_per_step'],
                true_image=True)
        self.frame_buffer = []

        # absolute_path = '/home/sam/PycharmProjects/SimFish/Assay-Output/new_differential_prey_ref-3' + f'/{self.assay_configuration_id}.h5'
        # hdf5_file = h5py.File(absolute_path, "a")
        hdf5_file = h5py.File(
            f"{self.data_save_location}/{self.assay_configuration_id}.h5", "a")

        try:
            assay_group = hdf5_file.create_group(assay['assay id'])
        except ValueError:
            assay_group = hdf5_file.get(assay['assay id'])

        if "prey_positions" in self.assay_output_data_format.keys():
            self.output_data["prey_positions"] = np.stack(
                self.output_data["prey_positions"])

        for key in self.output_data:
            try:
                # print(self.output_data[key])
                assay_group.create_dataset(
                    key, data=np.array(
                        self.output_data[key]))  # TODO: Compress data.
            except RuntimeError:
                del assay_group[key]
                assay_group.create_dataset(
                    key, data=np.array(
                        self.output_data[key]))  # TODO: Compress data.
        hdf5_file.close()

    def save_episode_data(self):
        self.episode_summary_data = {
            "Prey Caught": self.simulation.prey_caught,
            "Predators Avoided": self.simulation.predators_avoided,
            "Sand Grains Bumped": self.simulation.sand_grains_bumped,
            "Steps Near Vegetation": self.simulation.steps_near_vegetation
        }
        with open(
                f"{self.data_save_location}/{self.assay_configuration_id}-summary_data.json",
                "w") as output_file:
            json.dump(self.episode_summary_data, output_file)
        self.episode_summary_data = None

    def save_stimuli_data(self, assay):
        with open(
                f"{self.data_save_location}/{self.assay_configuration_id}-{assay['assay id']}-stimuli_data.json",
                "w") as output_file:
            json.dump(self.stimuli_data, output_file)
        self.stimuli_data = []

    def save_metadata(self):
        self.metadata["Assay Date"] = datetime.now().strftime(
            "%d/%m/%Y %H:%M:%S")
        with open(
                f"{self.data_save_location}/{self.assay_configuration_id}.json",
                "w") as output_file:
            json.dump(self.metadata, output_file)

    def package_output_data(self, observation, rev_observation, action,
                            advantage_stream, rnn_state, rnn2_state, position,
                            prey_consumed, predator_body, conv1l, conv2l,
                            conv3l, conv4l, conv1r, conv2r, conv3r, conv4r,
                            prey_positions, predator_position,
                            sand_grain_positions, vegetation_positions,
                            fish_angle):
        """

        :param action:
        :param advantage_stream:
        :param rnn_state:
        :param position:
        :param prey_consumed:
        :param predator_body: The predator body, if one is present this step; used to set the "predator" flag.
        :param conv1l:
        :param conv2l:
        :param conv3l:
        :param conv4l:
        :param conv1r:
        :param conv2r:
        :param conv3r:
        :param conv4r:
        :param prey_positions:
        :param predator_position:
        :param sand_grain_positions:
        :param vegetation_positions:
        :return:
        """
        # Make output data JSON serializable
        action = int(action)
        advantage_stream = advantage_stream.tolist()
        rnn_state = rnn_state.c.tolist()
        position = list(position)
        # observation = observation.tolist()

        data = {
            "behavioural choice": action,
            "rnn state": rnn_state,
            "rnn 2 state": rnn2_state,
            "advantage stream": advantage_stream,
            "position": position,
            "observation": observation,
            "rev_observation": rev_observation,
            "left_conv_1": conv1l,
            "left_conv_2": conv2l,
            "left_conv_3": conv3l,
            "left_conv_4": conv4l,
            "right_conv_1": conv1r,
            "right_conv_2": conv2r,
            "right_conv_3": conv3r,
            "right_conv_4": conv4r,
            "prey_positions": prey_positions,
            "predator_position": predator_position,
            "sand_grain_positions": sand_grain_positions,
            "vegetation_positions": vegetation_positions,
            "fish_angle": fish_angle,
            "hunger": self.simulation.fish.hungry,
            "stress": self.simulation.fish.stress,
        }

        if prey_consumed:
            data["consumed"] = 1
        else:
            data["consumed"] = 0
        if predator_body is not None:
            data["predator"] = 1
        else:
            data["predator"] = 0

        stimuli = self.simulation.stimuli_information
        to_save = {}
        for stimulus in stimuli.keys():
            if stimuli[stimulus]:
                to_save[stimulus] = stimuli[stimulus]

        if to_save:
            self.stimuli_data.append(to_save)

        return data

    def make_recordings(self, available_data):
        """No longer used - saves data in JSON"""

        step_data = {
            i: available_data[i]
            for i in self.assay_output_data_format
        }
        for d_type in step_data:
            self.assay_output_data[d_type].append(available_data[d_type])
        step_data["step"] = self.step_number
        self.assay_output_data.append(step_data)

    def save_assay_results(self, assay):
        """No longer used - saves data in JSON"""
        # Saves all the information from the assays in JSON format.
        if assay["save frames"]:
            make_gif(self.frame_buffer,
                     f"{self.data_save_location}/{assay['assay id']}.gif",
                     duration=len(self.frame_buffer) *
                     self.learning_params['time_per_step'],
                     true_image=True)

        self.frame_buffer = []
        with open(f"{self.data_save_location}/{assay['assay id']}.json",
                  "w") as output_file:
            json.dump(self.assay_output_data, output_file)
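
save_hdf5_data writes one HDF5 group per assay id into {data_save_location}/{assay_configuration_id}.h5, with one dataset per recorded key plus "step". A hedged read-back sketch; the file path, assay id and dataset names are hypothetical and must match what was actually recorded.

import h5py

# Hypothetical paths and names; adjust to the model id, assay configuration id
# and assay id used when the data was written.
with h5py.File("./Assay-Output/example_model-1/example_assays.h5", "r") as f:
    assay_group = f["Naturalistic-1"]       # one group per assay id
    steps = assay_group["step"][:]          # always written
    positions = assay_group["position"][:]  # present only if "position" was in assay["recordings"]
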
Example #7
            "step": 60,
            "position": [100, 300]
        },
    ]
}

dirname = os.path.dirname(__file__)
file_path = os.path.join(dirname,
                         f"Configurations/Assay-Configs/{arg}_env.json")

with open(file_path, 'r') as f:
    env = json.load(f)

# sim_state = ProjectionEnvironment(env, stimuli, tethered=True, draw_screen=True)
sim_state = NaturalisticEnvironment(env,
                                    realistic_bouts=True,
                                    draw_screen=True)

q = False
d = False
sim_state.reset()
while not q:
    # action = None
    key = input()
    action_input = int(key)

    previous_position = sim_state.fish.body.position

    if action_input < 10:
        s, r, internal, d, fb = sim_state.simulation_step(action_input)
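
As excerpted, the loop assumes every keypress parses as an integer and nothing shown ever sets q. A hedged variant of the same loop that tolerates malformed input and adds a quit key; treating "q" as the quit command is an assumption not present in the original.

# Hypothetical, more defensive manual-control loop; assumes sim_state exists as above.
sim_state.reset()
q = False
while not q:
    key = input()
    if key.strip().lower() == "q":
        q = True
        continue
    try:
        action_input = int(key)
    except ValueError:
        print("Enter an action number, or q to quit.")
        continue
    if action_input < 10:
        s, r, internal, d, fb = sim_state.simulation_step(action_input)
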
Example #8
    def __init__(self, model_name, trial_number, model_exists, tethered,
                 scaffold_name, episode_transitions, total_configurations,
                 conditional_transitions, e, total_steps, episode_number,
                 monitor_gpu, realistic_bouts, memory_fraction, using_gpu):
        """
        An instance of TrainingService handles the training of the DQN within a specified environment, according to
        specified parameters.
        :param model_name: The name of the model, usually to match the naming of the env configuration files.
        :param trial_number: The index of the trial, so that agents trained under the same configuration may be
        distinguished in their output files.
        """

        # Names and directories
        self.trial_id = f"{model_name}-{trial_number}"
        self.output_location = f"./Training-Output/{model_name}-{trial_number}"

        # Configurations
        self.scaffold_name = scaffold_name
        self.total_configurations = total_configurations
        self.episode_transitions = episode_transitions
        self.conditional_transitions = conditional_transitions
        self.tethered = tethered
        self.configuration_index = 1
        self.switched_configuration = False
        self.params, self.env = self.load_configuration_files()

        # Basic Parameters
        self.load_model = model_exists
        self.monitor_gpu = monitor_gpu
        self.using_gpu = using_gpu
        self.realistic_bouts = realistic_bouts
        self.memory_fraction = memory_fraction

        # Maintain variables
        if e is not None:
            self.e = e
        else:
            self.e = self.params["startE"]
        if episode_number is not None:
            self.episode_number = episode_number + 1
        else:
            self.episode_number = 0

        if total_steps is not None:
            self.total_steps = total_steps
        else:
            self.total_steps = 0

        # Network and Training Parameters
        self.saver = None
        self.writer = None
        self.main_QN, self.target_QN = None, None
        self.init = None
        self.trainables = None
        self.target_ops = None
        self.sess = None
        self.step_drop = (self.params['startE'] -
                          self.params['endE']) / self.params['anneling_steps']
        self.pre_train_steps = self.total_steps + self.params["pre_train_steps"]

        # Simulation
        self.simulation = NaturalisticEnvironment(self.env, realistic_bouts)
        self.realistic_bouts = realistic_bouts
        self.save_frames = False
        self.switched_configuration = True

        # Data
        self.training_buffer = ExperienceBuffer(
            output_location=self.output_location,
            buffer_size=self.params["exp_buffer_size"])
        self.frame_buffer = []
        self.training_times = []
        self.reward_list = []

        self.last_episodes_prey_caught = []
        self.last_episodes_predators_avoided = []
        self.last_episodes_sand_grains_bumped = []
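
The epsilon schedule implied by step_drop above is linear: epsilon falls from startE towards endE by step_drop per training step once pre-training is over. A small worked sketch with hypothetical parameter values (the key spelling "anneling_steps" matches the configuration files):

# Hypothetical values illustrating the linear epsilon annealing arithmetic.
params = {"startE": 1.0, "endE": 0.1, "anneling_steps": 100000}
step_drop = (params["startE"] - params["endE"]) / params["anneling_steps"]
e = params["startE"]
for _ in range(params["anneling_steps"]):
    if e > params["endE"]:
        e -= step_drop
print(round(e, 3))  # ends at roughly endE (0.1)
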
Example #9
class TrainingService:
    def __init__(self, model_name, trial_number, model_exists, tethered,
                 scaffold_name, episode_transitions, total_configurations,
                 conditional_transitions, e, total_steps, episode_number,
                 monitor_gpu, realistic_bouts, memory_fraction, using_gpu):
        """
        An instance of TrainingService handles the training of the DQN within a specified environment, according to
        specified parameters.
        :param model_name: The name of the model, usually to match the naming of the env configuration files.
        :param trial_number: The index of the trial, so that agents trained under the same configuration may be
        distinguished in their output files.
        """

        # Names and directories
        self.trial_id = f"{model_name}-{trial_number}"
        self.output_location = f"./Training-Output/{model_name}-{trial_number}"

        # Configurations
        self.scaffold_name = scaffold_name
        self.total_configurations = total_configurations
        self.episode_transitions = episode_transitions
        self.conditional_transitions = conditional_transitions
        self.tethered = tethered
        self.configuration_index = 1
        self.switched_configuration = False
        self.params, self.env = self.load_configuration_files()

        # Basic Parameters
        self.load_model = model_exists
        self.monitor_gpu = monitor_gpu
        self.using_gpu = using_gpu
        self.realistic_bouts = realistic_bouts
        self.memory_fraction = memory_fraction

        # Maintain variables
        if e is not None:
            self.e = e
        else:
            self.e = self.params["startE"]
        if episode_number is not None:
            self.episode_number = episode_number + 1
        else:
            self.episode_number = 0

        if total_steps is not None:
            self.total_steps = total_steps
        else:
            self.total_steps = 0

        # Network and Training Parameters
        self.saver = None
        self.writer = None
        self.main_QN, self.target_QN = None, None
        self.init = None
        self.trainables = None
        self.target_ops = None
        self.sess = None
        self.step_drop = (self.params['startE'] -
                          self.params['endE']) / self.params['anneling_steps']
        self.pre_train_steps = self.total_steps + self.params["pre_train_steps"]

        # Simulation
        self.simulation = NaturalisticEnvironment(self.env, realistic_bouts)
        self.realistic_bouts = realistic_bouts
        self.save_frames = False
        self.switched_configuration = True

        # Data
        self.training_buffer = ExperienceBuffer(
            output_location=self.output_location,
            buffer_size=self.params["exp_buffer_size"])
        self.frame_buffer = []
        self.training_times = []
        self.reward_list = []

        self.last_episodes_prey_caught = []
        self.last_episodes_predators_avoided = []
        self.last_episodes_sand_grains_bumped = []

    def load_configuration_files(self):
        """
        Called by the create_trials method; returns the learning and environment configurations loaded from JSON.
        :return:
        """
        print("Loading configuration...")
        configuration_location = f"./Configurations/{self.scaffold_name}/{str(self.configuration_index)}"
        with open(f"{configuration_location}_learning.json", 'r') as f:
            params = json.load(f)
        with open(f"{configuration_location}_env.json", 'r') as f:
            env = json.load(f)
        return params, env

    def run(self):
        """Run the simulation, either loading a checkpoint if there or starting from scratch. If loading, uses the
        previous checkpoint to set the episode number."""

        print("Running simulation")

        if self.using_gpu:
            # options = tf.GPUOptions(per_process_gpu_memory_fraction=self.memory_fraction)
            # config = tf.ConfigProto(gpu_options=options)
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
        else:
            config = None

        if config:
            with tf.Session(config=config) as self.sess:
                self._run()
        else:
            with tf.Session() as self.sess:
                self._run()

    def _run(self):
        self.main_QN, self.target_QN = self.create_networks()
        self.saver = tf.train.Saver(max_to_keep=5)
        self.init = tf.global_variables_initializer()
        self.trainables = tf.trainable_variables()
        self.target_ops = update_target_graph(self.trainables,
                                              self.params['tau'])
        if self.load_model:
            print(f"Attempting to load model at {self.output_location}")
            checkpoint = tf.train.get_checkpoint_state(self.output_location)
            if hasattr(checkpoint, "model_checkpoint_path"):
                self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
                print("Loading successful")

            else:
                print("No saved checkpoints found, starting from scratch.")
                self.sess.run(self.init)

            # if self.training_buffer.check_saved():  TODO: Consider adding training buffer
            #     print("Loading previous training buffer")
            #     self.training_buffer.load()
            # else:
            #     print("No existing training buffer")
        else:
            print("First attempt at running model. Starting from scratch.")
            self.sess.run(self.init)

        update_target(
            self.target_ops, self.sess
        )  # Set the target network to be equal to the primary network.
        self.writer = tf.summary.FileWriter(f"{self.output_location}/logs/",
                                            tf.get_default_graph())

        for e_number in range(self.episode_number,
                              self.params["num_episodes"]):
            self.episode_number = e_number
            if self.configuration_index < self.total_configurations:
                self.check_update_configuration()
            self.episode_loop()

    def switch_configuration(self, next_point):
        self.configuration_index = int(next_point)
        self.switched_configuration = True
        print(
            f"{self.trial_id}: Changing configuration to configuration {self.configuration_index}"
        )
        self.params, self.env = self.load_configuration_files()
        self.simulation = NaturalisticEnvironment(self.env,
                                                  self.realistic_bouts)

    def check_update_configuration(self):
        # TODO: Will want to tidy this up later.
        next_point = str(self.configuration_index + 1)
        episode_transition_points = self.episode_transitions.keys()

        if next_point in episode_transition_points:
            if self.episode_number > self.episode_transitions[next_point]:
                self.switch_configuration(next_point)
                return

        if len(self.last_episodes_prey_caught) >= 20:
            prey_conditional_transition_points = self.conditional_transitions[
                "Prey Caught"].keys()
            predators_conditional_transition_points = self.conditional_transitions[
                "Predators Avoided"].keys()
            grains_bumped_conditional_transfer_points = self.conditional_transitions[
                "Sand Grains Bumped"].keys()

            if next_point in predators_conditional_transition_points:
                if np.mean(self.last_episodes_predators_avoided
                           ) > self.conditional_transitions[
                               "Predators Avoided"][next_point]:
                    self.switch_configuration(next_point)
                    return

            if next_point in prey_conditional_transition_points:
                if np.mean(
                        self.last_episodes_prey_caught
                ) > self.conditional_transitions["Prey Caught"][next_point]:
                    self.switch_configuration(next_point)
                    return

            if next_point in grains_bumped_conditional_transfer_points:
                if np.mean(self.last_episodes_sand_grains_bumped
                           ) > self.conditional_transitions[
                               "Sand Grains Bumped"][next_point]:
                    self.switch_configuration(next_point)
                    return
        self.switched_configuration = False

    def create_networks(self):
        """
        Create the main and target Q networks, according to the configuration parameters.
        :return: The main network and the target network graphs.
        """
        print("Creating networks...")
        internal_states = sum([
            1 for x in [self.env['hunger'], self.env['stress']] if x is True
        ]) + 1
        cell = tf.nn.rnn_cell.LSTMCell(num_units=self.params['rnn_dim'],
                                       state_is_tuple=True)
        cell_t = tf.nn.rnn_cell.LSTMCell(num_units=self.params['rnn_dim'],
                                         state_is_tuple=True)
        main_QN = QNetwork(self.simulation,
                           self.params['rnn_dim'],
                           cell,
                           'main',
                           self.params['num_actions'],
                           internal_states=internal_states,
                           learning_rate=self.params['learning_rate'],
                           extra_layer=self.params['extra_rnn'])
        target_QN = QNetwork(self.simulation,
                             self.params['rnn_dim'],
                             cell_t,
                             'target',
                             self.params['num_actions'],
                             internal_states=internal_states,
                             learning_rate=self.params['learning_rate'],
                             extra_layer=self.params['extra_rnn'])
        return main_QN, target_QN

    def episode_loop(self):
        """
        Loops over an episode, which involves initialisation of the environment and RNN state, then iteration over the
        steps in the episode. The relevant values are then saved to the experience buffer.
        """
        t0 = time()
        episode_buffer = []

        rnn_state = (np.zeros([1, self.main_QN.rnn_dim]),
                     np.zeros([1, self.main_QN.rnn_dim])
                     )  # Reset RNN hidden state
        self.simulation.reset()
        sa = np.zeros((1, 128))  # Placeholder for the state advantage stream.
        sv = np.zeros((1, 128))  # Placeholder for the state value stream

        # Take the first simulation step, with a capture action. Assigns observation, reward, internal state, done, and the frame buffer.
        o, r, internal_state, d, self.frame_buffer = self.simulation.simulation_step(
            action=3,
            frame_buffer=self.frame_buffer,
            save_frames=self.save_frames,
            activations=(sa, ))

        # For benchmarking each episode.
        all_actions = []
        total_episode_reward = 0  # Total reward over episode

        step_number = 0  # To allow exit after maximum steps.
        a = 0  # Initialise action for episode.
        while step_number < self.params["max_epLength"]:
            step_number += 1
            o, a, r, internal_state, o1, d, rnn_state = self.step_loop(
                o=o, internal_state=internal_state, a=a, rnn_state=rnn_state)
            all_actions.append(a)
            episode_buffer.append(
                np.reshape(np.array([o, a, r, internal_state, o1, d]), [1, 6]))
            total_episode_reward += r
            o = o1
            if self.total_steps > self.pre_train_steps:
                if self.e > self.params['endE']:
                    self.e -= self.step_drop
                if self.total_steps % (self.params['update_freq']) == 0:
                    self.train_networks()
            if d:
                break
        # Add the episode to the experience buffer
        self.save_episode(
            episode_start_t=t0,
            all_actions=all_actions,
            total_episode_reward=total_episode_reward,
            episode_buffer=episode_buffer,
            prey_caught=self.simulation.prey_caught,
            predators_avoided=self.simulation.predators_avoided,
            sand_grains_bumped=self.simulation.sand_grains_bumped,
            steps_near_vegetation=self.simulation.steps_near_vegetation)
        # Print saved metrics
        # print(f"Total training time: {sum(self.training_times)}")
        # print(f"Total reward: {sum(self.reward_list)}")

    def save_episode(self, episode_start_t, all_actions, total_episode_reward,
                     episode_buffer, prey_caught, predators_avoided,
                     sand_grains_bumped, steps_near_vegetation):
        """
        Saves the episode to the experience buffer. Also creates a gif if at the save interval.
        :param episode_start_t: The time at the start of the episode, used to calculate the time the episode took.
        :param all_actions: The array of all the actions taken during the episode.
        :param total_episode_reward: The total reward of the episode.
        :param episode_buffer: A buffer containing all the state transitions, actions and associated rewards yielded by
        the environment.
        :return:
        """

        print(
            f"{self.trial_id} - episode {str(self.episode_number)}: num steps = {str(self.simulation.num_steps)}",
            flush=True)

        # # Log the average training time for episodes (when not saved)
        # if not self.save_frames:
        #     self.training_times.append(time() - episode_start_t)
        #     print(np.mean(self.training_times))

        # Keep the prey caught, predators avoided, and sand grains bumped from the last 20 episodes.
        self.last_episodes_prey_caught.append(prey_caught)
        self.last_episodes_predators_avoided.append(predators_avoided)
        self.last_episodes_sand_grains_bumped.append(sand_grains_bumped)
        if len(self.last_episodes_predators_avoided) > 20:
            self.last_episodes_prey_caught.pop(0)
            self.last_episodes_predators_avoided.pop(0)
            self.last_episodes_sand_grains_bumped.pop(0)

        # Add Summary to Logs
        episode_summary = tf.Summary(value=[
            tf.Summary.Value(tag="episode reward",
                             simple_value=total_episode_reward)
        ])
        self.writer.add_summary(episode_summary, self.total_steps)

        # Raw logs
        prey_caught_summary = tf.Summary(value=[
            tf.Summary.Value(tag="prey caught", simple_value=prey_caught)
        ])
        self.writer.add_summary(prey_caught_summary, self.episode_number)

        predators_avoided_summary = tf.Summary(value=[
            tf.Summary.Value(tag="predators avoided",
                             simple_value=predators_avoided)
        ])
        self.writer.add_summary(predators_avoided_summary, self.episode_number)

        sand_grains_bumped_summary = tf.Summary(value=[
            tf.Summary.Value(tag="attempted sand grain captures",
                             simple_value=sand_grains_bumped)
        ])
        self.writer.add_summary(sand_grains_bumped_summary,
                                self.episode_number)

        steps_near_vegetation_summary = tf.Summary(value=[
            tf.Summary.Value(tag="steps near vegetation",
                             simple_value=steps_near_vegetation)
        ])
        self.writer.add_summary(steps_near_vegetation_summary,
                                self.episode_number)

        # Normalised Logs
        if self.env["prey_num"] != 0:
            fraction_prey_caught = prey_caught / self.env["prey_num"]
            prey_caught_summary = tf.Summary(value=[
                tf.Summary.Value(tag="prey capture index (fraction caught)",
                                 simple_value=fraction_prey_caught)
            ])
            self.writer.add_summary(prey_caught_summary, self.episode_number)

        if self.env["probability_of_predator"] != 0:
            predator_avoided_index = predators_avoided / self.env[
                "probability_of_predator"]
            predators_avoided_summary = tf.Summary(value=[
                tf.Summary.Value(
                    tag="predator avoidance index (avoided/p_pred)",
                    simple_value=predator_avoided_index)
            ])
            self.writer.add_summary(predators_avoided_summary,
                                    self.episode_number)

        if self.env["sand_grain_num"] != 0:
            sand_grain_capture_index = sand_grains_bumped / self.env[
                "sand_grain_num"]
            sand_grains_bumped_summary = tf.Summary(value=[
                tf.Summary.Value(
                    tag="sand grain capture index (fraction attempted caught)",
                    simple_value=sand_grain_capture_index)
            ])
            self.writer.add_summary(sand_grains_bumped_summary,
                                    self.episode_number)

        if self.env["vegetation_num"] != 0:
            vegetation_index = (
                steps_near_vegetation /
                self.simulation.num_steps) / self.env["vegetation_num"]
            use_of_vegetation_summary = tf.Summary(value=[
                tf.Summary.Value(
                    tag=
                    "use of vegetation index (fraction_steps/vegetation_num",
                    simple_value=vegetation_index)
            ])
            self.writer.add_summary(use_of_vegetation_summary,
                                    self.episode_number)

        if self.switched_configuration:
            configuration_summary = tf.Summary(value=[
                tf.Summary.Value(tag="Configuration change",
                                 simple_value=self.configuration_index)
            ])
            self.writer.add_summary(configuration_summary, self.episode_number)

        for act in range(self.params['num_actions']):
            action_freq = np.sum(
                np.array(all_actions) == act) / len(all_actions)
            a_freq = tf.Summary(value=[
                tf.Summary.Value(tag="action " + str(act),
                                 simple_value=action_freq)
            ])
            self.writer.add_summary(a_freq, self.total_steps)

        # Save the parameters to be carried over.
        output_data = {
            "epsilon": self.e,
            "episode_number": self.episode_number,
            "total_steps": self.total_steps
        }
        with open(f"{self.output_location}/saved_parameters.json",
                  "w") as file:
            json.dump(output_data, file)

        buffer_array = np.array(episode_buffer)
        episode_buffer = list(zip(buffer_array))
        self.training_buffer.add(episode_buffer)
        self.reward_list.append(total_episode_reward)
        # Periodically save the model.
        if self.episode_number % self.params[
                'summaryLength'] == 0 and self.episode_number != 0:
            # print(f"mean time: {np.mean(self.training_times)}")

            # Save training buffer  TODO: Consider adding training buffer in future
            # print(f"Saving training buffer for {self.trial_id}")
            # self.training_buffer.save()

            # Save the model
            self.saver.save(
                self.sess,
                f"{self.output_location}/model-{str(self.episode_number)}.cptk"
            )
            print("Saved Model")

            # Create the GIF
            make_gif(
                self.frame_buffer,
                f"{self.output_location}/episodes/episode-{str(self.episode_number)}.gif",
                duration=len(self.frame_buffer) * self.params['time_per_step'],
                true_image=True)
            self.frame_buffer = []
            self.save_frames = False

        if (self.episode_number + 1) % self.params['summaryLength'] == 0:
            print('starting to save frames', flush=True)
            self.save_frames = True
        if self.monitor_gpu:
            print(f"GPU usage {os.system('gpustat -cp')}")

    def step_loop(self, o, internal_state, a, rnn_state):
        """
        Runs a single step: chooses an action (randomly during pre-training/exploration, otherwise via the network)
        given the current observation and state, then applies it in the environment.

        :param
        session: The TF session.
        internal_state: The internal state of the agent - whether it is in light, and whether it is hungry.
        a: The previous chosen action.
        rnn_state: The state inside the RNN.

        :return:
        s: The environment state.
        chosen_a: The action chosen randomly/by the network.
        given_reward: The reward returned.
        internal_state: The internal state of the agent - whether it is in light, and whether it is hungry.
        s1: The subsequent environment state
        d: Boolean indicating agent death.
        updated_rnn_state: The updated RNN state
        """

        # Generate actions and corresponding steps.
        if np.random.rand(
                1) < self.e or self.total_steps < self.pre_train_steps:
            [updated_rnn_state, sa, sv] = self.sess.run(
                [
                    self.main_QN.rnn_state, self.main_QN.streamA,
                    self.main_QN.streamV
                ],
                feed_dict={
                    self.main_QN.observation: o,
                    self.main_QN.internal_state: internal_state,
                    self.main_QN.prev_actions: [a],
                    self.main_QN.trainLength: 1,
                    self.main_QN.state_in: rnn_state,
                    self.main_QN.batch_size: 1,
                    self.main_QN.exp_keep: 1.0
                })
            chosen_a = np.random.randint(0, self.params['num_actions'])
        else:
            chosen_a, updated_rnn_state, sa, sv = self.sess.run(
                [
                    self.main_QN.predict, self.main_QN.rnn_state,
                    self.main_QN.streamA, self.main_QN.streamV
                ],
                feed_dict={
                    self.main_QN.observation: o,
                    self.main_QN.internal_state: internal_state,
                    self.main_QN.prev_actions: [a],
                    self.main_QN.trainLength: 1,
                    self.main_QN.state_in: rnn_state,
                    self.main_QN.batch_size: 1,
                    self.main_QN.exp_keep: 1.0
                })
            chosen_a = chosen_a[0]

        # Simulation step
        o1, given_reward, internal_state, d, self.frame_buffer = self.simulation.simulation_step(
            action=chosen_a,
            frame_buffer=self.frame_buffer,
            save_frames=self.save_frames,
            activations=(sa, ))
        self.total_steps += 1
        return o, chosen_a, given_reward, internal_state, o1, d, updated_rnn_state

    def train_networks(self):
        """
        Trains the two networks, copying over the target network
        :return:
        """
        update_target(self.target_ops, self.sess)
        # Reset the recurrent layer's hidden state
        state_train = (np.zeros([
            self.params['batch_size'], self.main_QN.rnn_dim
        ]), np.zeros([self.params['batch_size'], self.main_QN.rnn_dim]))

        # Get a random batch of experiences: ndarray 1024x6, with the six columns containing o, a, r, i_s, o1, d
        train_batch = self.training_buffer.sample(self.params['batch_size'],
                                                  self.params['trace_length'])

        # Below we perform the Double-DQN update to the target Q-values
        Q1 = self.sess.run(self.main_QN.predict,
                           feed_dict={
                               self.main_QN.observation:
                               np.vstack(train_batch[:, 4]),
                               self.main_QN.prev_actions:
                               np.hstack(([0], train_batch[:-1, 1])),
                               self.main_QN.trainLength:
                               self.params['trace_length'],
                               self.main_QN.internal_state:
                               np.vstack(train_batch[:, 3]),
                               self.main_QN.state_in:
                               state_train,
                               self.main_QN.batch_size:
                               self.params['batch_size'],
                               self.main_QN.exp_keep:
                               1.0
                           })

        Q2 = self.sess.run(self.target_QN.Q_out,
                           feed_dict={
                               self.target_QN.observation:
                               np.vstack(train_batch[:, 4]),
                               self.target_QN.prev_actions:
                               np.hstack(([0], train_batch[:-1, 1])),
                               self.target_QN.trainLength:
                               self.params['trace_length'],
                               self.target_QN.internal_state:
                               np.vstack(train_batch[:, 3]),
                               self.target_QN.state_in:
                               state_train,
                               self.target_QN.batch_size:
                               self.params['batch_size'],
                               self.target_QN.exp_keep:
                               1.0
                           })

        end_multiplier = -(train_batch[:, 5] - 1)

        double_Q = Q2[range(self.params['batch_size'] *
                            self.params['trace_length']), Q1]
        target_Q = train_batch[:, 2] + (self.params['y'] * double_Q *
                                        end_multiplier)
        # Update the network with our target values.
        self.sess.run(self.main_QN.updateModel,
                      feed_dict={
                          self.main_QN.observation:
                          np.vstack(train_batch[:, 0]),
                          self.main_QN.targetQ:
                          target_Q,
                          self.main_QN.actions:
                          train_batch[:, 1],
                          self.main_QN.internal_state:
                          np.vstack(train_batch[:, 3]),
                          self.main_QN.prev_actions:
                          np.hstack(([3], train_batch[:-1, 1])),
                          self.main_QN.trainLength:
                          self.params['trace_length'],
                          self.main_QN.state_in:
                          state_train,
                          self.main_QN.batch_size:
                          self.params['batch_size'],
                          self.main_QN.exp_keep:
                          1.0
                      })
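
A hedged construction-and-run sketch based only on the signature above and the transition dictionaries consulted in check_update_configuration; every concrete name, threshold and value below is a placeholder.

# Hypothetical instantiation; names and numbers are illustrative only.
service = TrainingService(model_name="example_model",
                          trial_number=1,
                          model_exists=False,
                          tethered=False,
                          scaffold_name="example_scaffold",
                          episode_transitions={"2": 200},   # move to configuration 2 after 200 episodes
                          total_configurations=2,
                          conditional_transitions={"Prey Caught": {},
                                                   "Predators Avoided": {},
                                                   "Sand Grains Bumped": {}},
                          e=None,                 # None -> start from params["startE"]
                          total_steps=None,
                          episode_number=None,
                          monitor_gpu=False,
                          realistic_bouts=True,
                          memory_fraction=0.5,
                          using_gpu=False)
service.run()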