Ejemplo n.º 1
0
class MultiAgentGraphManager(object):
    """
    A graph manager that runs one or more agents against a single shared environment.
    """
    def __init__(
            self,
            agents_params: List[AgentParameters],
            env_params: EnvironmentParameters,
            schedule_params: ScheduleParameters,
            vis_params: VisualizationParameters = VisualizationParameters(),
            preset_validation_params: PresetValidationParameters = PresetValidationParameters(),
            done_condition=any):
        """
        Store the run configuration and prepare the parameters of every agent.

        :param agents_params: one parameters object per agent. Each entry is modified in place: it is renamed
                              ("agent" when there is a single agent, "agent_<index>" otherwise), receives a shallow
                              copy of vis_params, and gets the environment's default input / output filter when its
                              own filter is None
        :param env_params: parameters of the single environment; also the source of the default filters
        :param schedule_params: forwarded to set_schedule_params
        :param vis_params: visualization parameters, kept on the graph manager and copied to each agent
        :param preset_validation_params: stored as-is on the instance
        :param done_condition: stored as-is on the instance (defaults to the builtin any)
        """
        # Assign the final agent names first. The per-agent dictionaries below are keyed by name, so building
        # them before the rename would leave them keyed by names that no agent carries anymore.
        single_agent = len(agents_params) == 1
        for agent_index, agent_params in enumerate(agents_params):
            agent_params.name = "agent" if single_agent else "agent_{}".format(agent_index)

        self.done_condition = done_condition
        self.sess = {agent_params.name: None for agent_params in agents_params}
        self.level_managers = []  # type: List[MultiAgentLevelManager]
        self.top_level_manager = None
        self.environments = []
        self.set_schedule_params(schedule_params)
        self.visualization_parameters = vis_params
        self.name = 'multi_agent_graph'
        self.task_parameters = None
        # going through the phase setter also propagates the phase to the (still empty) levels and environments
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = preset_validation_params
        self.reset_required = False
        self.num_checkpoints_to_keep = 4  # TODO: make this a parameter

        # timers
        self.graph_creation_time = None
        self.last_checkpoint_saving_time = time.time()

        # counters
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        self.checkpoint_saver = {
            agent_params.name: None
            for agent_params in agents_params
        }
        self.checkpoint_state_updater = None
        self.graph_logger = Logger()
        self.data_store = None
        self.is_batch_rl = False
        self.time_metric = TimeTypes.EpisodeNumber

        self.env_params = env_params
        self.agents_params = agents_params
        self.agent_params = agents_params[0]  # ...(find a better way)...

        for agent_params in agents_params:
            agent_params.visualization = copy.copy(vis_params)
            # fall back to the environment's default filters when the agent did not define its own
            if agent_params.input_filter is None:
                agent_params.input_filter = copy.copy(
                    env_params.default_input_filter())
            if agent_params.output_filter is None:
                agent_params.output_filter = copy.copy(
                    env_params.default_output_filter())

    def create_graph(self,
                     task_parameters=None,
                     stop_physics=None,
                     start_physics=None,
                     empty_service_call=None):
        """
        Build the level managers and the environment, create the framework session(s) and set up the loggers.

        :param task_parameters: the task parameters to run with. When None, a fresh TaskParameters() is created
                                for this call, so that later in-place changes to it (for example the checkpoint
                                directory filled in by save_checkpoint) are not shared between graph managers
        :param stop_physics: optional callable, invoked with the result of empty_service_call() after the graph
                             modules were created
        :param start_physics: optional callable, invoked with the result of empty_service_call() before the graph
                              modules are created
        :param empty_service_call: optional factory for the argument handed to start_physics / stop_physics;
                                   when it is falsy neither of them is called
        :return: self
        """
        if task_parameters is None:
            task_parameters = TaskParameters()

        self.graph_creation_time = time.time()
        self.task_parameters = task_parameters

        if isinstance(task_parameters, DistributedTaskParameters):
            screen.log_title(
                "Creating graph - name: {} task id: {} type: {}".format(
                    self.__class__.__name__, task_parameters.task_index,
                    task_parameters.job_type))
        else:
            screen.log_title("Creating graph - name: {}".format(
                self.__class__.__name__))

        # "hide" the gpu if necessary
        if task_parameters.use_cpu:
            set_cpu()

        # create a target server for the worker and a device
        if isinstance(task_parameters, DistributedTaskParameters):
            task_parameters.worker_target, task_parameters.device = \
                self.create_worker_or_parameters_server(task_parameters=task_parameters)
        # If necessary start the physics and then stop it after agent creation
        screen.log_title("Start physics before creating graph")
        if start_physics and empty_service_call:
            start_physics(empty_service_call())
        # create the graph modules
        screen.log_title("Create graph")
        self.level_managers, self.environments = self._create_graph(
            task_parameters)
        screen.log_title("Stop physics after creating graph")
        if stop_physics and empty_service_call:
            stop_physics(empty_service_call())
        # set self as the parent of all the level managers
        self.top_level_manager = self.level_managers[0]
        for level_manager in self.level_managers:
            level_manager.parent_graph_manager = self

        # create a session (it needs to be created after all the graph ops were created)
        self.sess = {
            agent_params.name: None
            for agent_params in self.agents_params
        }
        screen.log_title("Creating session")
        self.create_session(task_parameters=task_parameters)
        # the setter pushes the phase down to the freshly created level managers and environments
        self._phase = self.phase = RunPhase.UNDEFINED

        self.setup_logger()

        return self

    def _create_graph(
        self, task_parameters: TaskParameters
    ) -> Tuple[List[MultiAgentLevelManager], List[Environment]]:
        """
        Instantiate the environment, every agent and the single level manager that ties them together.

        :param task_parameters: source of the seed and experiment path written into the environment parameters;
                                a shallow copy of it is attached to each agent's parameters
        :return: a one-element list with the level manager and a one-element list with the environment
        """
        # environment loading
        env_params = self.env_params
        env_params.seed = task_parameters.seed
        env_params.experiment_path = task_parameters.experiment_path
        environment_class = short_dynamic_import(env_params.path)
        env = environment_class(
            **env_params.__dict__,
            visualization_parameters=self.visualization_parameters)

        # agent loading
        agents = OrderedDict()
        for params in self.agents_params:
            params.task_parameters = copy.copy(task_parameters)
            agent_class = short_dynamic_import(params.path)
            agent = agent_class(params)
            agents[params.name] = agent
            screen.log_title("Created agent: {}".format(params.name))

            # only when memory backend parameters were attached and describe a rollout worker
            is_rollout_worker = hasattr(self, 'memory_backend_params') and \
                self.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER)
            if is_rollout_worker:
                agent.memory.memory_backend = deepracer_memory.DeepRacerRolloutBackEnd(
                    self.memory_backend_params,
                    params.algorithm.num_consecutive_playing_steps,
                    params.name)

        # set level manager
        main_level = MultiAgentLevelManager(
            agents=agents,
            environment=env,
            name="main_level",
            done_condition=self.done_condition)
        return [main_level], [env]

    @staticmethod
    def _create_worker_or_parameters_server_tf(
            task_parameters: DistributedTaskParameters):
        """
        Start a TensorFlow parameters server, or create a worker server and device, depending on the job type.

        :param task_parameters: supplies the parameters server hosts, worker hosts, job type, task index and
                                the use_cpu flag
        :return: the result of create_worker_server_and_device for a "worker" job; nothing is returned for a
                 "ps" job
        :raises ValueError: when the job type is neither "ps" nor "worker"
        """
        import tensorflow as tf
        session_config = tf.ConfigProto()
        # allow placing ops on cpu if they are not fit for gpu
        session_config.allow_soft_placement = True
        # allow the gpu memory allocated for the worker to grow if needed
        session_config.gpu_options.allow_growth = True
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.2
        session_config.intra_op_parallelism_threads = 1
        session_config.inter_op_parallelism_threads = 1

        from rl_coach.architectures.tensorflow_components.distributed_tf_utils import (
            create_and_start_parameters_server,
            create_cluster_spec,
            create_worker_server_and_device,
        )

        # the cluster spec is built before the job type is validated
        cluster_spec = create_cluster_spec(
            parameters_server=task_parameters.parameters_server_hosts,
            workers=task_parameters.worker_hosts)

        job_type = task_parameters.job_type
        if job_type == "worker":
            return create_worker_server_and_device(
                cluster_spec=cluster_spec,
                task_index=task_parameters.task_index,
                use_cpu=task_parameters.use_cpu,
                config=session_config)
        if job_type != "ps":
            raise ValueError(
                "The job type should be either ps or worker and not {}".format(
                    task_parameters.job_type))

        # parameters server (the original comment describes this call as non-returning)
        create_and_start_parameters_server(cluster_spec=cluster_spec,
                                           config=session_config)

    @staticmethod
    def create_worker_or_parameters_server(
            task_parameters: DistributedTaskParameters):
        """
        Create the distributed worker / parameters server for the framework selected in the task parameters.

        :param task_parameters: the distributed task parameters; framework_type selects the implementation
        :return: whatever the framework specific implementation returns
        :raises NotImplementedError: for MXNet, which has no distributed implementation here
        :raises ValueError: for any other unknown framework
        """
        framework = task_parameters.framework_type
        if framework == Frameworks.tensorflow:
            # dispatch to this class's own implementation instead of reaching into another class
            return MultiAgentGraphManager._create_worker_or_parameters_server_tf(
                task_parameters)
        if framework == Frameworks.mxnet:
            raise NotImplementedError(
                'Distributed training not implemented for MXNet')
        raise ValueError('Invalid framework {}'.format(
            task_parameters.framework_type))

    def _create_session_tf(self, task_parameters: TaskParameters):
        """
        Create the TensorFlow session(s), hand them to all the modules and save the static graph if requested.

        For distributed task parameters a single monitored session is created; otherwise one tf.Session is
        created per agent and stored in a dictionary keyed by the agent name.

        :param task_parameters: the task parameters describing the run
        """
        import tensorflow as tf
        tf_config = tf.ConfigProto()
        # allow placing ops on cpu if they are not fit for gpu
        tf_config.allow_soft_placement = True
        # allow the gpu memory allocated for the worker to grow if needed
        tf_config.gpu_options.allow_growth = True
        tf_config.intra_op_parallelism_threads = 1
        tf_config.inter_op_parallelism_threads = 1

        if not isinstance(task_parameters, DistributedTaskParameters):
            # regular session
            print("Creating regular session")
            self.sess = {
                agent_params.name: tf.Session(config=tf_config)
                for agent_params in self.agents_params
            }
        else:
            # the distributed tensorflow setting
            from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session

            restore_requested = hasattr(self.task_parameters, 'checkpoint_restore_path') \
                and self.task_parameters.checkpoint_restore_path
            if restore_requested:
                checkpoint_dir = os.path.join(task_parameters.experiment_path,
                                              'checkpoint')
                if os.path.exists(checkpoint_dir):
                    remove_tree(checkpoint_dir)
                # in the locally distributed case, checkpoints are always restored from a directory (and not
                # from a file)
                copy_tree(task_parameters.checkpoint_restore_path,
                          checkpoint_dir)
            else:
                checkpoint_dir = task_parameters.checkpoint_save_dir

            self.sess = create_monitored_session(
                target=task_parameters.worker_target,
                task_index=task_parameters.task_index,
                checkpoint_dir=checkpoint_dir,
                checkpoint_save_secs=task_parameters.checkpoint_save_secs,
                config=tf_config)

        # set the session for all the modules
        self.set_session(self.sess)

        # the TF graph is static, and therefore is saved once - in the beginning of the experiment
        save_dir_given = hasattr(self.task_parameters, 'checkpoint_save_dir') \
            and self.task_parameters.checkpoint_save_dir
        if save_dir_given:
            self.save_graph()

    def _create_session_mx(self):
        """
        Call set_session to initialize parameters and construct checkpoint_saver
        """
        self.set_session(sess=None)  # Initialize all modules

    def create_session(self, task_parameters: TaskParameters):
        """
        Create the framework session, collect the checkpoint savers of every agent and restore a checkpoint
        if one was requested.

        :param task_parameters: the task parameters; framework_type selects the session implementation
        :raises ValueError: when the framework type is neither tensorflow nor mxnet
        """
        framework = task_parameters.framework_type
        if framework == Frameworks.tensorflow:
            self._create_session_tf(task_parameters)
        elif framework == Frameworks.mxnet:
            self._create_session_mx()
        else:
            raise ValueError('Invalid framework {}'.format(
                task_parameters.framework_type))

        # Create parameter saver: one empty collection per agent, filled level by level
        savers = {}
        for agent_params in self.agents_params:
            savers[agent_params.name] = SaverCollection()
        self.checkpoint_saver = savers

        for level in self.level_managers:
            for agent_params in self.agents_params:
                agent_name = agent_params.name
                self.checkpoint_saver[agent_name].update(
                    level.collect_savers(agent_name))

        # restore from checkpoint if given
        self.restore_checkpoint()

    def save_graph(self) -> None:
        """
        Write the default TF graph as a binary protobuf file named 'graphdef.pb' into the checkpoint save
        directory of the task parameters
        :return: None
        """
        import tensorflow as tf

        # write graph
        default_graph = tf.get_default_graph()
        target_dir = self.task_parameters.checkpoint_save_dir
        tf.train.write_graph(default_graph,
                             logdir=target_dir,
                             name='graphdef.pb',
                             as_text=False)

    def _save_onnx_graph_tf(self) -> None:
        """
        Save the tensorflow graph as an ONNX graph.
        The names of the online network inputs (except those whose key starts with "output_") and of all the
        online network outputs are gathered over every level, agent and network, and handed to save_onnx_graph
        together with the checkpoint save directory.
        :return: None
        """
        # collect input and output nodes
        input_names = []
        output_names = []
        for level in self.level_managers:
            for agent in level.agents.values():
                for network in agent.networks.values():
                    online_inputs = network.online_network.inputs
                    input_names.extend(
                        tensor.name
                        for key, tensor in online_inputs.items()
                        if not key.startswith("output_"))
                    online_outputs = network.online_network.outputs
                    output_names.extend(tensor.name for tensor in online_outputs)

        from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph

        save_onnx_graph(input_names, output_names,
                        self.task_parameters.checkpoint_save_dir)

    def save_onnx_graph(self) -> None:
        """
        Save the graph as an ONNX graph.
        Only the tensorflow framework is handled; for any other framework type this method does nothing.
        :return: None
        """
        uses_tensorflow = self.task_parameters.framework_type == Frameworks.tensorflow
        if uses_tensorflow:
            self._save_onnx_graph_tf()

    def setup_logger(self) -> None:
        # dump documentation
        logger_prefix = "{graph_name}".format(graph_name=self.name)
        self.graph_logger.set_logger_filenames(
            self.task_parameters.experiment_path,
            logger_prefix=logger_prefix,
            add_timestamp=True,
            task_id=self.task_parameters.task_index)
        if self.visualization_parameters.dump_parameters_documentation:
            self.graph_logger.dump_documentation(str(self))
        [manager.setup_logger() for manager in self.level_managers]

    @property
    def phase(self) -> RunPhase:
        """
        Get the phase of the graph
        :return: the current phase
        """
        return self._phase

    @phase.setter
    def phase(self, val: RunPhase):
        """
        Change the phase of the graph and all the hierarchy levels below it
        :param val: the new phase
        :return: None
        """
        self._phase = val
        for level_manager in self.level_managers:
            level_manager.phase = val
        for environment in self.environments:
            environment.phase = val
            environment._notify_phase(val)

    @property
    def current_step_counter(self) -> TotalStepsCounter:
        return self.total_steps_counters[self.phase]

    @contextlib.contextmanager
    def phase_context(self, phase):
        """
        Create a context which temporarily sets the phase to the provided phase.
        The previous phase is restored afterwards.
        """
        old_phase = self.phase
        self.phase = phase
        yield
        self.phase = old_phase

    def set_session(self, sess) -> None:
        """
        Set the deep learning framework session for all the modules in the graph
        :return: None
        """
        [manager.set_session(sess) for manager in self.level_managers]

    def heatup(self, steps: PlayingStepsType) -> None:
        """
        Perform heatup for several steps, which means taking random actions and storing the results in memory
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        if steps.num_steps > 0:
            with self.phase_context(RunPhase.HEATUP):
                screen.log_title("{}: Starting heatup".format(self.name))

                # reset all the levels before starting to heatup
                self.reset_internal_state(force_environment_reset=True)

                # act for at least steps, though don't interrupt an episode
                count_end = self.current_step_counter + steps
                while self.current_step_counter < count_end:
                    self.act(EnvironmentEpisodes(1))

    def handle_episode_ended(self) -> None:
        """
        Count one more finished episode for the current phase and let every environment handle the end of
        the episode
        :return: None
        """
        self.current_step_counter[EnvironmentEpisodes] += 1

        for environment in self.environments:
            environment.handle_episode_ended()

    def train(self) -> None:
        """
        Perform one training call on every level manager, inside the TRAIN phase, and count it as one
        training step
        :return: None
        """
        self.verify_graph_was_created()

        with self.phase_context(RunPhase.TRAIN):
            self.current_step_counter[TrainingSteps] += 1
            for manager in self.level_managers:
                manager.train()

    def reset_internal_state(self, force_environment_reset=False) -> None:
        """
        Reset an episode for all the levels
        :param force_environment_reset: force the environment to reset the episode even if it has some conditions that
                                        tell it not to. for example, if ale life is lost, gym will tell the agent that
                                        the episode is finished but won't actually reset the episode if there are more
                                        lives available
        :return: None
        """
        self.verify_graph_was_created()

        self.reset_required = False
        [
            environment.reset_internal_state(force_environment_reset)
            for environment in self.environments
        ]
        [manager.reset_internal_state() for manager in self.level_managers]

    def act(self,
            steps: PlayingStepsType,
            wait_for_full_episodes=False) -> None:
        """
        Do several steps of acting on the environment
        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
        :param steps: the number of steps as a tuple of steps time and steps count
        """
        self.verify_graph_was_created()

        # perform several steps of playing
        count_end = self.current_step_counter + steps
        done = False
        while self.current_step_counter < count_end or (wait_for_full_episodes
                                                        and not done):
            # reset the environment if the previous episode was terminated
            if self.reset_required:
                self.reset_internal_state()

            steps_begin = self.environments[0].total_steps_counter
            done = self.top_level_manager.step(None)
            steps_end = self.environments[0].total_steps_counter

            if done:
                self.handle_episode_ended()
                self.reset_required = True

            self.current_step_counter[EnvironmentSteps] += (steps_end -
                                                            steps_begin)

            # if no steps were made (can happen when no actions are taken while in the TRAIN phase, either in batch RL
            # or in imitation learning), we force end the loop, so that it will not continue forever.
            if (steps_end - steps_begin) == 0:
                break

    def train_and_act(self, steps: StepMethod) -> None:
        """
        Train the agent by doing several acting steps followed by several training steps continually
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        # perform several steps of training interleaved with acting
        if steps.num_steps > 0:
            with self.phase_context(RunPhase.TRAIN):
                self.reset_internal_state(force_environment_reset=True)

                count_end = self.current_step_counter + steps
                while self.current_step_counter < count_end:
                    # The actual number of steps being done on the environment
                    # is decided by the agent, though this inner loop always
                    # takes at least one step in the environment (at the GraphManager level).
                    # The agent might also decide to skip acting altogether.
                    # Depending on internal counters and parameters, it doesn't always train or save checkpoints.
                    self.act(EnvironmentSteps(1))
                    self.train()
                    self.occasionally_save_checkpoint()

    def sync(self) -> None:
        """
        Sync the global network parameters to the graph
        :return:
        """
        [manager.sync() for manager in self.level_managers]

    def evaluate(self, steps: PlayingStepsType) -> bool:
        """
        Perform evaluation for several steps
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: bool, True if the target reward and target success has been reached
        """
        self.verify_graph_was_created()

        if steps.num_steps > 0:
            with self.phase_context(RunPhase.TEST):
                # reset all the levels before starting to evaluate
                self.reset_internal_state(force_environment_reset=True)
                self.sync()

                # act for at least `steps`, though don't interrupt an episode
                count_end = self.current_step_counter + steps
                while self.current_step_counter < count_end:
                    self.act(EnvironmentEpisodes(1))
                    self.sync()
        if self.should_stop():
            self.flush_finished()
            screen.success("Reached required success rate. Exiting.")
            return True
        return False

    def improve(self):
        """
        The main loop of the run.
        Defined in the following steps:
        1. Sync the network parameters
        2. Heatup
        3. Repeat until the TRAIN counter advanced by improve_steps, or an evaluation asks to stop:
            3.1. Train and act (acting, training and possibly saving a checkpoint)
            3.2. Evaluate
        :return: None
        """
        self.verify_graph_was_created()

        # initialize the network parameters from the global network
        self.sync()

        # heatup
        self.heatup(self.heatup_steps)

        # improve
        task_index = self.task_parameters.task_index
        if task_index is None:
            title = "Starting to improve {}".format(self.name)
        else:
            title = "Starting to improve {} task index {}".format(
                self.name, task_index)
        screen.log_title(title)

        target_count = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
        while self.total_steps_counters[RunPhase.TRAIN] < target_count:
            self.train_and_act(self.steps_between_evaluation_periods)
            should_stop = self.evaluate(self.evaluation_steps)
            if should_stop:
                break

    def restore_checkpoint(self):
        """
        Restore every agent from the checkpoint given by task_parameters.checkpoint_restore_path.

        Nothing happens when no restore path is set. With a single agent the restore path is used directly;
        with several agents each agent restores from the sub-path named after it. The path may be a checkpoint
        directory or, for tensorflow only, a checkpoint file. When restoring from a directory, the checkpoint
        id is advanced past the latest checkpoint found and one old checkpoint is deleted if more than
        num_checkpoints_to_keep checkpoints are listed.

        :raises NotImplementedError: for a tensorflow directory holding a 'checkpoint' file (monitored session)
        :raises ValueError: when no checkpoint is found, when a checkpoint file is given for a non tensorflow
                            framework, or when restoring an agent fails (the original error is chained)
        """
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if not self.task_parameters.checkpoint_restore_path:
            return

        restored_checkpoint_paths = []
        for agent_params in self.agents_params:
            # Only a directory restore yields a checkpoint state. Resetting it per agent keeps the pruning
            # step below from failing on an unbound name, or reusing another agent's state, for a file restore.
            checkpoint = None

            if len(self.agents_params) == 1:
                agent_checkpoint_restore_path = self.task_parameters.checkpoint_restore_path
            else:
                agent_checkpoint_restore_path = os.path.join(
                    self.task_parameters.checkpoint_restore_path,
                    agent_params.name)

            if os.path.isdir(agent_checkpoint_restore_path):
                # a checkpoint dir
                if self.task_parameters.framework_type == Frameworks.tensorflow and\
                        'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                    # TODO-fixme checkpointing
                    # MonitoredTrainingSession manages save/restore checkpoints autonomously. Doing so,
                    # it creates it own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                    # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                    # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file, results in an error when trying to
                    # restore the model, as the checkpoint names defined do not match the actual checkpoint names.
                    raise NotImplementedError(
                        'Checkpointing not implemented for TF monitored training session'
                    )
                else:
                    checkpoint = get_checkpoint_state(
                        agent_checkpoint_restore_path,
                        all_checkpoints=True)

                if checkpoint is None:
                    raise ValueError(
                        "No checkpoint to restore in: {}".format(
                            agent_checkpoint_restore_path))
                model_checkpoint_path = checkpoint.model_checkpoint_path
                checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                restored_checkpoint_paths.append(model_checkpoint_path)

                # Set the last checkpoint ID - only in the case of the path being a dir
                chkpt_state_reader = CheckpointStateReader(
                    agent_checkpoint_restore_path,
                    checkpoint_state_optional=False)
                self.checkpoint_id = chkpt_state_reader.get_latest(
                ).num + 1
            else:
                # a checkpoint file
                if self.task_parameters.framework_type == Frameworks.tensorflow:
                    model_checkpoint_path = agent_checkpoint_restore_path
                    checkpoint_restore_dir = os.path.dirname(
                        model_checkpoint_path)
                    restored_checkpoint_paths.append(model_checkpoint_path)
                else:
                    raise ValueError(
                        "Currently restoring a checkpoint using the --checkpoint_restore_file argument is"
                        " only supported when with tensorflow.")

            try:
                self.checkpoint_saver[agent_params.name].restore(
                    self.sess[agent_params.name], model_checkpoint_path)
            except Exception as ex:
                # keep the original failure attached as the cause
                raise ValueError(
                    "Failed to restore {}'s checkpoint: {}".format(
                        agent_params.name, ex)) from ex

            # pruning needs the list of checkpoints, which only exists for a directory restore
            if checkpoint is not None:
                all_checkpoints = sorted(
                    {c.name for c in checkpoint.all_checkpoints})  # remove duplicates :-(
                if self.num_checkpoints_to_keep < len(all_checkpoints):
                    checkpoint_to_delete = all_checkpoints[
                        -self.num_checkpoints_to_keep - 1]
                    agent_checkpoint_to_delete = os.path.join(
                        agent_checkpoint_restore_path, checkpoint_to_delete)
                    for stale_file in glob.glob(
                            "{}*".format(agent_checkpoint_to_delete)):
                        os.remove(stale_file)

        for manager in self.level_managers:
            manager.restore_checkpoint(checkpoint_restore_dir)
        for manager in self.level_managers:
            manager.post_training_commands()

        # NOTE(review): every entry uses the same key, so with several agents only the last path is kept
        screen.log_dict(OrderedDict([
            ("Restoring from path", restored_checkpoint_path)
            for restored_checkpoint_path in restored_checkpoint_paths
        ]),
                        prefix="Checkpoint")

    def _get_checkpoint_state_tf(self, checkpoint_restore_dir):
        """
        Look up the TensorFlow checkpoint state of a directory
        :param checkpoint_restore_dir: the directory passed to tf.train.get_checkpoint_state
        :return: the value returned by tf.train.get_checkpoint_state
        """
        import tensorflow as tf
        checkpoint_state = tf.train.get_checkpoint_state(checkpoint_restore_dir)
        return checkpoint_state

    def occasionally_save_checkpoint(self):
        # only the chief process saves checkpoints
        if self.task_parameters.checkpoint_save_secs \
                and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.checkpoint_save_secs \
                and (self.task_parameters.task_index == 0  # distributed
                     or self.task_parameters.task_index is None  # single-worker
                     ):
            self.save_checkpoint()

    def save_checkpoint(self):
        """
        Save a checkpoint for every agent and record it in the checkpoint state file.

        The checkpoint directory defaults to '<experiment_path>/checkpoint' when none is configured, and is
        created (including missing parent directories) if needed. With a single agent the checkpoint is written
        straight into that directory; with several agents each agent uses a sub-directory named after it. For
        distributed task parameters nothing is written through the saver here and only the path is reported.
        After saving, one old checkpoint per agent is deleted once more than num_checkpoints_to_keep are tracked.
        """
        # create current session's checkpoint directory
        if self.task_parameters.checkpoint_save_dir is None:
            self.task_parameters.checkpoint_save_dir = os.path.join(
                self.task_parameters.experiment_path, 'checkpoint')

        # Create directory structure; makedirs also creates missing parents and tolerates an existing directory
        os.makedirs(self.task_parameters.checkpoint_save_dir, exist_ok=True)

        if self.checkpoint_state_updater is None:
            self.checkpoint_state_updater = CheckpointStateUpdater(
                self.task_parameters.checkpoint_save_dir)

        checkpoint_name = "{}_Step-{}.ckpt".format(
            self.checkpoint_id,
            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])

        saved_checkpoint_paths = []
        for agent_params in self.agents_params:
            if len(self.agents_params) == 1:
                agent_checkpoint_save_dir = self.task_parameters.checkpoint_save_dir
            else:
                agent_checkpoint_save_dir = os.path.join(
                    self.task_parameters.checkpoint_save_dir,
                    agent_params.name)
            os.makedirs(agent_checkpoint_save_dir, exist_ok=True)

            agent_checkpoint_path = os.path.join(agent_checkpoint_save_dir,
                                                 checkpoint_name)
            if not isinstance(self.task_parameters, DistributedTaskParameters):
                saved_checkpoint_paths.append(
                    self.checkpoint_saver[agent_params.name].save(
                        self.sess[agent_params.name], agent_checkpoint_path))
            else:
                saved_checkpoint_paths.append(agent_checkpoint_path)

            # drop the checkpoint that just fell out of the "keep" window, for this agent's directory
            if self.num_checkpoints_to_keep < len(
                    self.checkpoint_state_updater.all_checkpoints):
                checkpoint_to_delete = self.checkpoint_state_updater.all_checkpoints[
                    -self.num_checkpoints_to_keep - 1]
                agent_checkpoint_to_delete = os.path.join(
                    agent_checkpoint_save_dir, checkpoint_to_delete.name)
                for stale_file in glob.glob(
                        "{}*".format(agent_checkpoint_to_delete)):
                    os.remove(stale_file)

        # this is required in order for agents to save additional information like a DND for example
        for manager in self.level_managers:
            manager.save_checkpoint(checkpoint_name)

        # Purge Redis memory after saving the checkpoint as Transitions are no longer needed at this point.
        if hasattr(self, 'memory_backend'):
            self.memory_backend.memory_purge()

        # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
        if self.task_parameters.export_onnx_graph:
            self.save_onnx_graph()

        # write the new checkpoint name to a file to signal this checkpoint has been fully saved
        self.checkpoint_state_updater.update(
            SingleCheckpoint(self.checkpoint_id, checkpoint_name))

        # NOTE(review): every entry uses the same key, so with several agents only the last path is kept
        screen.log_dict(OrderedDict([
            ("Saving in path", saved_checkpoint_path)
            for saved_checkpoint_path in saved_checkpoint_paths
        ]),
                        prefix="Checkpoint")

        self.checkpoint_id += 1
        self.last_checkpoint_saving_time = time.time()

        if hasattr(self, 'data_store_params'):
            data_store = self.get_data_store(self.data_store_params)
            data_store.save_to_store()

    def verify_graph_was_created(self):
        """
        Verifies that the graph was already created, and if not, it creates it with the default task parameters
        :return: None
        """
        if self.graph_creation_time is None:
            self.create_graph()

    def __str__(self):
        result = ""
        for key, val in self.__dict__.items():
            params = ""
            if isinstance(val, list) or isinstance(val, dict) or isinstance(
                    val, OrderedDict):
                items = iterable_to_items(val)
                for k, v in items:
                    params += "{}: {}\n".format(k, v)
            else:
                params = val
            result += "{}: \n{}\n".format(key, params)

        return result

    def should_train(self) -> bool:
        return any([manager.should_train() for manager in self.level_managers])

    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
    #               an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
    def emulate_act_on_trainer(self, steps: PlayingStepsType,
                               transitions: Dict[str, Transition]) -> None:
        """
        Emulate, on the training worker, acting steps that were performed by a rollout worker
        (distributed training). Instead of stepping the environment, the given transitions are handed
        to the top level manager.
        :param steps: the amount of playing to emulate; it is added to the current step counter to get the
                      point at which the loop stops
        :param transitions: the transitions obtained from the rollout worker
        :return: None
        """
        self.verify_graph_was_created()

        stop_at = self.current_step_counter + steps
        while self.current_step_counter < stop_at:
            if self.reset_required:
                # the previous episode was terminated, so start from a clean internal state
                self.reset_internal_state()

            env_steps_before = self.environments[0].total_steps_counter
            episode_done = self.top_level_manager.emulate_step_on_trainer(transitions)
            env_steps_after = self.environments[0].total_steps_counter

            # Count only the difference in the environment's total steps around the emulated step, so that
            # environment initialization steps (like in Atari) are not counted. At least one step is always
            # added, which guarantees the loop terminates even if no steps were made.
            steps_taken = env_steps_after - env_steps_before
            self.current_step_counter[EnvironmentSteps] += max(1, steps_taken)

            if not episode_done:
                continue
            self.handle_episode_ended()
            self.reset_required = True

    def fetch_from_worker(self, num_consecutive_playing_steps=None):
        if hasattr(self, 'memory_backend'):
            for transitions in self.memory_backend.fetch(
                    num_consecutive_playing_steps):
                self.emulate_act_on_trainer(EnvironmentSteps(1), transitions)
                if hasattr(self, 'sample_collector'):
                    self.sample_collector.sample(transitions)

    def setup_memory_backend(self) -> None:
        if hasattr(self, 'memory_backend_params'):
            self.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(
                self.memory_backend_params, self.agents_params)

    def should_stop(self) -> bool:
        return self.task_parameters.apply_stop_condition and all(
            [manager.should_stop() for manager in self.level_managers])

    def get_data_store(self, param):
        if self.data_store:
            return self.data_store

        return data_store_creator(param)

    def signal_ready(self):
        if self.task_parameters.checkpoint_save_dir and os.path.exists(
                self.task_parameters.checkpoint_save_dir):
            open(
                os.path.join(self.task_parameters.checkpoint_save_dir,
                             SyncFiles.TRAINER_READY.value), 'w').close()
        if hasattr(self, 'data_store_params'):
            data_store = self.get_data_store(self.data_store_params)
            data_store.save_to_store()

    def close(self) -> None:
        """
        Clean up to close environments.

        :return: None
        """
        for env in self.environments:
            env.close()

    def get_current_episodes_count(self):
        """
        Read the EnvironmentEpisodes entry of the current step counter.
        :return: the value stored under EnvironmentEpisodes in ``current_step_counter``
        """
        step_counter = self.current_step_counter
        return step_counter[EnvironmentEpisodes]

    def flush_finished(self):
        """
        To indicate the training has finished, writes a `.finished` file to the checkpoint directory and calls
        the data store to updload that file.
        """
        if self.task_parameters.checkpoint_save_dir and os.path.exists(
                self.task_parameters.checkpoint_save_dir):
            open(
                os.path.join(self.task_parameters.checkpoint_save_dir,
                             SyncFiles.FINISHED.value), 'w').close()
        if hasattr(self, 'data_store_params'):
            data_store = self.get_data_store(self.data_store_params)
            data_store.save_to_store()

    def set_schedule_params(self, schedule_params: ScheduleParameters):
        """
        Copy the scheduling fields of the given parameters onto the graph manager.

        :param schedule_params: the schedule params to take heatup_steps, evaluation_steps,
                                steps_between_evaluation_periods and improve_steps from.
        """
        schedule_fields = (
            'heatup_steps',
            'evaluation_steps',
            'steps_between_evaluation_periods',
            'improve_steps',
        )
        for field_name in schedule_fields:
            setattr(self, field_name, getattr(schedule_params, field_name))
Ejemplo n.º 2
0
class GraphManager(object):
    """
    A graph manager is responsible for creating and initializing a graph of agents, including all its internal
    components. It is then used in order to schedule the execution of operations on the graph, such as acting and
    training.
    """
    def __init__(
        self,
        name: str,
        schedule_params: ScheduleParameters,
        vis_params: VisualizationParameters = None):
        """
        :param name: the name of the graph manager (also used as the logger file prefix)
        :param schedule_params: the heatup / evaluation / improve step budgets of the run
        :param vis_params: the visualization parameters. When omitted (None), a fresh
                           VisualizationParameters instance is created for this graph manager.
        """
        # A `VisualizationParameters()` default in the signature is evaluated only once, when the function
        # is defined, so every graph manager built without vis_params would share (and mutate) the very
        # same object. Creating the default here gives each instance its own copy.
        if vis_params is None:
            vis_params = VisualizationParameters()

        self.sess = None
        self.level_managers = []
        self.top_level_manager = None
        self.environments = []
        self.heatup_steps = schedule_params.heatup_steps
        self.evaluation_steps = schedule_params.evaluation_steps
        self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
        self.improve_steps = schedule_params.improve_steps
        self.visualization_parameters = vis_params
        self.name = name
        self.task_parameters = None
        # the phase setter iterates level_managers and environments, so both must already exist here
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = PresetValidationParameters()
        self.reset_required = False

        # timers
        self.graph_initialization_time = time.time()
        self.graph_creation_time = None
        self.heatup_start_time = None
        self.training_start_time = None
        self.last_evaluation_start_time = None
        self.last_checkpoint_saving_time = time.time()

        # counters
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        self.checkpoint_saver = None
        self.graph_logger = Logger()

    def create_graph(self, task_parameters: TaskParameters = None):
        """
        Create the graph: the level managers and environments (through _create_graph), the session and
        the loggers.
        :param task_parameters: the parameters of the task. When omitted (None), a fresh TaskParameters
                                instance is created for this call.
        :return: the graph manager itself, to allow chaining
        """
        # A `TaskParameters()` default in the signature is evaluated only once, at definition time, and
        # the object is stored on the instance below - so all graph managers created with the default would
        # end up sharing one parameters object. Build the default per call instead.
        if task_parameters is None:
            task_parameters = TaskParameters()

        self.graph_creation_time = time.time()
        self.task_parameters = task_parameters

        is_distributed = isinstance(task_parameters, DistributedTaskParameters)

        if is_distributed:
            screen.log_title(
                "Creating graph - name: {} task id: {} type: {}".format(
                    self.__class__.__name__, task_parameters.task_index,
                    task_parameters.job_type))
        else:
            screen.log_title("Creating graph - name: {}".format(
                self.__class__.__name__))

        # "hide" the gpu if necessary
        if task_parameters.use_cpu:
            set_cpu()

        # create a target server for the worker and a device
        if is_distributed:
            task_parameters.worker_target, task_parameters.device = \
                self.create_worker_or_parameters_server(task_parameters=task_parameters)

        # create the graph modules
        self.level_managers, self.environments = self._create_graph(
            task_parameters)

        # set self as the parent of all the level managers
        self.top_level_manager = self.level_managers[0]
        for level_manager in self.level_managers:
            level_manager.parent_graph_manager = self

        # create a session (it needs to be created after all the graph ops were created)
        self.sess = None
        self.create_session(task_parameters=task_parameters)

        # assigning through the property propagates the phase to the newly created modules
        self._phase = self.phase = RunPhase.UNDEFINED

        self.setup_logger()

        return self

    def _create_graph(
        self, task_parameters: TaskParameters
    ) -> Tuple[List[LevelManager], List[Environment]]:
        """
        Create all the graph modules and the graph scheduler.
        This is a hook that concrete graph managers must override; the base implementation only raises.
        :param task_parameters: the parameters of the task
        :return: the initialized level managers and environments
        :raises NotImplementedError: always, in this base class
        """
        # name the offending class so that a missing override is easy to track down
        raise NotImplementedError(
            "{} must implement _create_graph()".format(
                self.__class__.__name__))

    def create_worker_or_parameters_server(
            self, task_parameters: DistributedTaskParameters):
        """
        Take this process' role in a distributed TensorFlow cluster, according to the job type.
        :param task_parameters: the distributed task parameters, holding the cluster hosts, the job type
                                ("ps" or "worker"), the task index and the use_cpu flag
        :return: for a worker - the result of create_worker_server_and_device (create_graph unpacks it into
                 a worker target and a device). For a parameters server the started server is documented
                 here as non-returning; should it return, None is returned.
        :raises ValueError: if the job type is neither "ps" nor "worker"
        """
        import tensorflow as tf
        tf_config = tf.ConfigProto()
        # allow placing ops on cpu if they are not fit for gpu
        tf_config.allow_soft_placement = True
        # allow the gpu memory allocated for the worker to grow if needed
        tf_config.gpu_options.allow_growth = True
        tf_config.gpu_options.per_process_gpu_memory_fraction = 0.2
        tf_config.intra_op_parallelism_threads = 1
        tf_config.inter_op_parallelism_threads = 1

        from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_and_start_parameters_server, \
            create_cluster_spec, create_worker_server_and_device

        cluster = create_cluster_spec(
            parameters_server=task_parameters.parameters_server_hosts,
            workers=task_parameters.worker_hosts)

        job_type = task_parameters.job_type
        if job_type == "worker":
            # create a worker and a device setter
            return create_worker_server_and_device(
                cluster_spec=cluster,
                task_index=task_parameters.task_index,
                use_cpu=task_parameters.use_cpu,
                config=tf_config)
        if job_type == "ps":
            # create and start the parameters server (non-returning function)
            create_and_start_parameters_server(cluster_spec=cluster,
                                               config=tf_config)
            return None
        raise ValueError(
            "The job type should be either ps or worker and not {}".format(
                job_type))

    def create_session(self, task_parameters: DistributedTaskParameters):
        """
        Create the TensorFlow session of the graph and set it on all the level managers.

        - DistributedTaskParameters: a monitored session is created. If a checkpoint_restore_dir is set,
          its content is first copied into "<experiment_path>/checkpoint" (replacing an existing directory)
          and that directory is used as the checkpoint dir; otherwise save_checkpoint_dir is used.
        - any other task parameters: a Saver over all the global variables and a regular session are
          created, and restore_checkpoint() is called.
        :param task_parameters: the parameters of the task (non-distributed parameters are handled as well,
                                despite the annotation)
        :return: None
        """
        import tensorflow as tf
        session_config = tf.ConfigProto()
        # allow placing ops on cpu if they are not fit for gpu
        session_config.allow_soft_placement = True
        # allow the gpu memory allocated for the worker to grow if needed
        session_config.gpu_options.allow_growth = True
        # session_config.gpu_options.per_process_gpu_memory_fraction = 0.2
        session_config.intra_op_parallelism_threads = 1
        session_config.inter_op_parallelism_threads = 1

        if not isinstance(task_parameters, DistributedTaskParameters):
            self.variables_to_restore = tf.global_variables()
            # self.variables_to_restore = [v for v in self.variables_to_restore if '/online' in v.name] TODO: is this necessary?
            self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)

            # regular session, shared with all the modules
            self.sess = tf.Session(config=session_config)
            self.set_session(self.sess)

            # restore from checkpoint if given
            self.restore_checkpoint()
        else:
            # the distributed tensorflow setting
            from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
            if getattr(self.task_parameters, 'checkpoint_restore_dir', None):
                checkpoint_dir = os.path.join(task_parameters.experiment_path,
                                              'checkpoint')
                if os.path.exists(checkpoint_dir):
                    remove_tree(checkpoint_dir)
                copy_tree(task_parameters.checkpoint_restore_dir,
                          checkpoint_dir)
            else:
                checkpoint_dir = task_parameters.save_checkpoint_dir

            self.sess = create_monitored_session(
                target=task_parameters.worker_target,
                task_index=task_parameters.task_index,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_secs=task_parameters.save_checkpoint_secs,
                config=session_config)
            # set the session for all the modules
            self.set_session(self.sess)

        # NOTE: a disabled (commented-out) experiment used to follow here. It wrote the default graph to
        # 'graphdef.pb' in save_checkpoint_dir, saved a checkpoint, collected the online network output
        # node names of all the agents and ran `python -m tensorflow.python.tools.freeze_graph` to produce
        # 'frozen_graph.pb'.

    def setup_logger(self) -> None:
        # dump documentation
        logger_prefix = "{graph_name}".format(graph_name=self.name)
        self.graph_logger.set_logger_filenames(
            self.task_parameters.experiment_path,
            logger_prefix=logger_prefix,
            add_timestamp=True,
            task_id=self.task_parameters.task_index)
        if self.visualization_parameters.dump_parameters_documentation:
            self.graph_logger.dump_documentation(str(self))
        [manager.setup_logger() for manager in self.level_managers]

    @property
    def phase(self) -> RunPhase:
        """
        The current phase of the graph
        :return: the phase that was last stored by the setter
        """
        return self._phase

    @phase.setter
    def phase(self, val: RunPhase):
        """
        Store the new phase and push it down to every level manager and then to every environment
        :param val: the new phase
        :return: None
        """
        self._phase = val
        for manager in self.level_managers:
            manager.phase = val
        for env in self.environments:
            env.phase = val

    def set_session(self, sess) -> None:
        """
        Set the deep learning framework session for all the modules in the graph
        :return: None
        """
        [manager.set_session(sess) for manager in self.level_managers]

    def heatup(self, steps: PlayingStepsType) -> None:
        """
        Run the heatup phase: act on the environment, one full episode at a time, until the requested
        amount of steps was consumed. The given steps object is copied, so the caller's object is not
        modified.
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        remaining = copy.copy(steps)
        if not remaining.num_steps > 0:
            return

        self.phase = RunPhase.HEATUP
        screen.log_title("{}: Starting heatup".format(self.name))
        self.heatup_start_time = time.time()

        # reset all the levels before starting to heatup
        self.reset_internal_state(force_environment_reset=True)

        # act on the environment, deducting what each call reports as done
        while remaining.num_steps > 0:
            played, _ = self.act(remaining,
                                 continue_until_game_over=True,
                                 return_on_game_over=True)
            remaining.num_steps -= played

        # heatup is over - leave the HEATUP phase
        self.phase = RunPhase.UNDEFINED

    def handle_episode_ended(self) -> None:
        """
        Account for a finished episode by incrementing the EnvironmentEpisodes counter of the
        current phase. The internal state is not reset here.
        :return: None
        """
        phase_counters = self.total_steps_counters[self.phase]
        phase_counters[EnvironmentEpisodes] += 1

        # TODO: we should disentangle ending the episode from resetting the internal state
        # self.reset_internal_state()

    def train(self, steps: TrainingSteps) -> None:
        """
        Run training iterations on all the level managers. In each iteration the TrainingSteps counter
        of the TRAIN phase is incremented first, and then every level manager is trained once.
        :param steps: number of training iterations to perform
        :return: None
        """
        self.verify_graph_was_created()

        train_counters = self.total_steps_counters[RunPhase.TRAIN]
        stop_at = train_counters[TrainingSteps] + steps.num_steps
        while train_counters[TrainingSteps] < stop_at:
            train_counters[TrainingSteps] += 1
            for manager in self.level_managers:
                manager.train()

    def reset_internal_state(self, force_environment_reset=False) -> None:
        """
        Reset an episode for all the levels
        :param force_environment_reset: force the environment to reset the episode even if it has some conditions that
                                        tell it not to. for example, if ale life is lost, gym will tell the agent that
                                        the episode is finished but won't actually reset the episode if there are more
                                        lives available
        :return: None
        """
        self.verify_graph_was_created()

        self.reset_required = False
        [
            environment.reset_internal_state(force_environment_reset)
            for environment in self.environments
        ]
        [manager.reset_internal_state() for manager in self.level_managers]

    def act(self,
            steps: PlayingStepsType,
            return_on_game_over: bool = False,
            continue_until_game_over=False,
            keep_networks_in_sync=False) -> Tuple[int, bool]:
        """
        Do several steps of acting on the environment
        :param steps: the number of steps as a tuple of steps time and steps count
        :param return_on_game_over: finish acting if an episode is finished
        :param continue_until_game_over: continue playing until an episode was completed
        :param keep_networks_in_sync: sync the network parameters with the global network before each episode
        :return: the actual number of steps done, measured in the same unit as `steps` (i.e. the progress of
                 the counter of type(steps) during this call), and a boolean value that represent if the
                 episode was done when finishing the function call
        """
        self.verify_graph_was_created()

        # perform several steps of playing
        result = None

        step_type = steps.__class__
        hold_until_a_full_episode = bool(continue_until_game_over)
        initial_count = self.total_steps_counters[self.phase][step_type]
        count_end = initial_count + steps.num_steps

        def steps_done():
            # The progress must be read from the same counter that initial_count was taken from.
            # Subtracting initial_count from the EnvironmentSteps counter (as was done before) mixes
            # units whenever `steps` is not an EnvironmentSteps, e.g. for episode based heatup.
            return self.total_steps_counters[self.phase][step_type] - initial_count

        # The assumption here is that the total_steps_counters are each updated when an event
        #  takes place (i.e. an episode ends)
        # TODO - The counter of frames is not updated correctly. need to fix that.
        while self.total_steps_counters[self.phase][
                step_type] < count_end or hold_until_a_full_episode:
            # reset the environment if the previous episode was terminated
            if self.reset_required:
                self.reset_internal_state()

            current_steps = self.environments[0].total_steps_counter

            result = self.top_level_manager.step(None)
            # result will be None if at least one level_manager decided not to play (= all of his agents did not play)
            # causing the rest of the level_managers down the stack not to play either, and thus the entire graph did
            # not act
            if result is None:
                break

            # add the diff between the total steps before and after stepping, such that environment initialization steps
            # (like in Atari) will not be counted.
            # We add at least one step so that even if no steps were made (in case no actions are taken in the training
            # phase), the loop will end eventually.
            self.total_steps_counters[self.phase][EnvironmentSteps] += \
                max(1, self.environments[0].total_steps_counter - current_steps)

            if result.game_over:
                hold_until_a_full_episode = False
                self.handle_episode_ended()
                self.reset_required = True
                if keep_networks_in_sync:
                    self.sync_graph()
                if return_on_game_over:
                    return steps_done(), True

        # return the game over status (False when the graph did not act at all)
        game_over = result.game_over if result else False
        return steps_done(), game_over

    def train_and_act(self, steps: StepMethod) -> None:
        """
        Run the TRAIN phase: repeatedly act for one environment step, train for one training step and
        give save_checkpoint() a chance to store a checkpoint, until the counter of the given step type
        advanced by the requested amount.
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        if not steps.num_steps > 0:
            return

        self.phase = RunPhase.TRAIN
        step_type = steps.__class__
        stop_at = self.total_steps_counters[self.phase][step_type] + steps.num_steps
        self.reset_internal_state(force_environment_reset=True)

        # TODO - the below while loop should end with full episodes, so to avoid situations where we have partial
        #  episodes in memory
        while self.total_steps_counters[self.phase][step_type] < stop_at:
            # This is just an high-level controller - the actual steps being done on the environment are
            # decided by the agents themselves.
            self.act(EnvironmentSteps(1))
            self.train(TrainingSteps(1))
            self.save_checkpoint()

        self.phase = RunPhase.UNDEFINED

    def sync_graph(self) -> None:
        """
        Sync the global network parameters to the graph
        :return:
        """
        [manager.sync() for manager in self.level_managers]

    def evaluate(self,
                 steps: PlayingStepsType,
                 keep_networks_in_sync: bool = False) -> None:
        """
        Run the TEST phase: after resetting all the levels and syncing the graph, act one full episode at
        a time until the counter of the given step type advanced by the requested amount.
        :param steps: the number of steps as a tuple of steps time and steps count
        :param keep_networks_in_sync: sync the network parameters with the global network before each episode
        :return: None
        """
        self.verify_graph_was_created()

        if not steps.num_steps > 0:
            return

        self.phase = RunPhase.TEST
        self.last_evaluation_start_time = time.time()

        # reset all the levels before starting to evaluate
        self.reset_internal_state(force_environment_reset=True)
        self.sync_graph()

        step_type = steps.__class__
        stop_at = self.total_steps_counters[self.phase][step_type] + steps.num_steps
        while self.total_steps_counters[self.phase][step_type] < stop_at:
            # the returned values are not needed - progress is tracked through the counters
            _steps_done, _game_over = self.act(
                steps,
                continue_until_game_over=True,
                return_on_game_over=True,
                keep_networks_in_sync=keep_networks_in_sync)

        self.phase = RunPhase.UNDEFINED

    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if hasattr(self.task_parameters, 'checkpoint_restore_dir'
                   ) and self.task_parameters.checkpoint_restore_dir:
            import tensorflow as tf
            checkpoint_dir = self.task_parameters.checkpoint_restore_dir
            checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
            screen.log_title("Loading checkpoint: {}".format(
                checkpoint.model_checkpoint_path))
            variables = {}
            for var_name, _ in tf.contrib.framework.list_variables(
                    self.task_parameters.checkpoint_restore_dir):
                # Load the variable
                var = tf.contrib.framework.load_variable(
                    checkpoint_dir, var_name)

                # Set the new name
                new_name = var_name
                new_name = new_name.replace('global/', 'online/')
                variables[new_name] = var

            for v in self.variables_to_restore:
                self.sess.run(v.assign(variables[v.name.split(':')[0]]))

    def save_checkpoint(self):
        # only the chief process saves checkpoints
        if self.task_parameters.save_checkpoint_secs \
                and time.time() - self.last_checkpoint_saving_time >= self.task_parameters.save_checkpoint_secs \
                and (self.task_parameters.task_index == 0  # distributed
                     or self.task_parameters.task_index is None  # single-worker
                     ):

            checkpoint_path = os.path.join(
                self.task_parameters.save_checkpoint_dir,
                "{}_Step-{}.ckpt".format(
                    self.checkpoint_id, self.total_steps_counters[
                        RunPhase.TRAIN][EnvironmentSteps]))
            if not isinstance(self.task_parameters, DistributedTaskParameters):
                saved_checkpoint_path = self.checkpoint_saver.save(
                    self.sess, checkpoint_path)
            else:
                saved_checkpoint_path = checkpoint_path

            # this is required in order for agents to save additional information like a DND for example
            [
                manager.save_checkpoint(self.checkpoint_id)
                for manager in self.level_managers
            ]

            screen.log_dict(OrderedDict([
                ("Saving in path", saved_checkpoint_path),
            ]),
                            prefix="Checkpoint")

            self.checkpoint_id += 1
            self.last_checkpoint_saving_time = time.time()

    def improve(self):
        """
        The main loop of the run.
        Defined in the following steps:
        1. Sync the graph and heatup
        2. Repeat, until the TRAIN counter of the improve steps type reaches the improve steps budget:
            2.1. Train and act for steps_between_evaluation_periods
                 (act, train and possibly save a checkpoint, repeatedly)
            2.2. Evaluate for evaluation_steps
        :return: None
        """
        self.verify_graph_was_created()

        # initialize the network parameters from the global network
        self.sync_graph()

        # heatup
        self.heatup(self.heatup_steps)

        # improve
        task_index = self.task_parameters.task_index
        if task_index is None:
            screen.log_title("Starting to improve {}".format(self.name))
        else:
            screen.log_title("Starting to improve {} task index {}".format(
                self.name, task_index))

        self.training_start_time = time.time()
        total_budget = self.improve_steps.num_steps
        while self.total_steps_counters[RunPhase.TRAIN][
                self.improve_steps.__class__] < total_budget:
            self.train_and_act(self.steps_between_evaluation_periods)
            self.evaluate(self.evaluation_steps)

    def verify_graph_was_created(self):
        """
        Verifies that the graph was already created, and if not, it creates it with the default task parameters
        :return: None
        """
        if self.graph_creation_time is None:
            self.create_graph()

    def __str__(self):
        result = ""
        for key, val in self.__dict__.items():
            params = ""
            if isinstance(val, list) or isinstance(val, dict) or isinstance(
                    val, OrderedDict):
                items = iterable_to_items(val)
                for k, v in items:
                    params += "{}: {}\n".format(k, v)
            else:
                params = val
            result += "{}: \n{}\n".format(key, params)

        return result
Ejemplo n.º 3
0
class Agent(AgentInterface):
    def __init__(self,
                 agent_parameters: AgentParameters,
                 parent: Union['LevelManager', 'CompositeAgent'] = None):
        """
        Build the agent: obtain its memory, choose its devices, bind its filters to a device, reset all the
        internal counters and register the default signals.
        :param agent_parameters: A Preset class instance with all the running parameters
        :param parent: the object holding this agent (assigned through the parent property setter), or None
        """
        super().__init__()
        self.ap = agent_parameters
        self.task_id = self.ap.task_parameters.task_index
        self.is_chief = self.task_id == 0
        # shared memory is used only when the task parameters are exactly of type DistributedTaskParameters and
        # the memory parameters request it
        self.shared_memory = type(agent_parameters.task_parameters) == DistributedTaskParameters \
                             and self.ap.memory.shared_memory
        if self.shared_memory:
            self.shared_memory_scratchpad = self.ap.task_parameters.shared_memory_scratchpad
        self.name = agent_parameters.name
        # goes through the parent property setter, which reads self.name and self.ap when parent is not None,
        # so both must already be assigned at this point
        self.parent = parent
        self.parent_level_manager = None
        # NOTE(review): when a parent is given, the setter above has just set the full name to
        # "<parent name>/<name>" and this line replaces it with the plain name - confirm this is intended
        self.full_name_id = agent_parameters.full_name_id = self.name

        if type(agent_parameters.task_parameters) == DistributedTaskParameters:
            screen.log_title(
                "Creating agent - name: {} task id: {} (may take up to 30 seconds due to "
                "tensorflow wake up time)".format(self.full_name_id,
                                                  self.task_id))
        else:
            screen.log_title("Creating agent - name: {}".format(
                self.full_name_id))
        self.imitation = False
        self.agent_logger = Logger()
        self.agent_episode_logger = EpisodeLogger()

        # get the memory
        # - distributed training + shared memory:
        #   * is chief?  -> create the memory and add it to the scratchpad
        #   * not chief? -> wait for the chief to create the memory and then fetch it
        # - non distributed training / not shared memory:
        #   * create memory
        # the memory name is the second ':'-separated field of the memory path
        memory_name = self.ap.memory.path.split(':')[1]
        self.memory_lookup_name = self.full_name_id + '.' + memory_name
        if self.shared_memory and not self.is_chief:
            self.memory = self.shared_memory_scratchpad.get(
                self.memory_lookup_name)
        else:
            # modules
            if agent_parameters.memory.load_memory_from_file_path:
                # a pre-recorded memory is read from a pickle file instead of instantiating a new one
                screen.log_title(
                    "Loading replay buffer from pickle. Pickle path: {}".
                    format(agent_parameters.memory.load_memory_from_file_path))
                self.memory = read_pickle(
                    agent_parameters.memory.load_memory_from_file_path)
            else:
                self.memory = dynamic_import_and_instantiate_module_from_params(
                    self.ap.memory)

            if self.shared_memory and self.is_chief:
                self.shared_memory_scratchpad.add(self.memory_lookup_name,
                                                  self.memory)

        # set devices
        if type(agent_parameters.task_parameters) == DistributedTaskParameters:
            self.has_global = True
            self.replicated_device = agent_parameters.task_parameters.device
            self.worker_device = "/job:worker/task:{}".format(self.task_id)
        else:
            self.has_global = False
            self.replicated_device = None
            self.worker_device = ""
        # the cpu / gpu suffix is appended to whichever worker device prefix was chosen above
        if agent_parameters.task_parameters.use_cpu:
            self.worker_device += "/cpu:0"
        else:
            self.worker_device += "/device:GPU:0"

        # filters
        self.input_filter = self.ap.input_filter
        self.output_filter = self.ap.output_filter
        self.pre_network_filter = self.ap.pre_network_filter
        # the filters are placed on the replicated device when there is one, and on the worker device otherwise
        device = self.replicated_device if self.replicated_device else self.worker_device
        self.input_filter.set_device(device)
        self.output_filter.set_device(device)
        self.pre_network_filter.set_device(device)

        # initialize all internal variables
        # _phase is assigned directly, so the phase property setter is not invoked here
        self._phase = RunPhase.HEATUP
        self.total_shaped_reward_in_current_episode = 0
        self.total_reward_in_current_episode = 0
        self.total_steps_counter = 0
        self.running_reward = None
        self.training_iteration = 0
        self.last_target_network_update_step = 0
        self.last_training_phase_step = 0
        self.current_episode = self.ap.current_episode = 0
        self.curr_state = {}
        self.current_hrl_goal = None
        self.current_episode_steps_counter = 0
        self.episode_running_info = {}
        self.last_episode_evaluation_ran = 0
        self.running_observations = []
        self.agent_logger.set_current_time(self.current_episode)
        # the exploration policy and the networks are created later, in init_environment_dependent_modules
        self.exploration_policy = None
        self.networks = {}
        self.last_action_info = None
        self.running_observation_stats = None
        self.running_reward_stats = None
        self.accumulated_rewards_across_evaluation_episodes = 0
        self.accumulated_shaped_rewards_across_evaluation_episodes = 0
        self.num_successes_across_evaluation_episodes = 0
        self.num_evaluation_episodes_completed = 0
        self.current_episode_buffer = Episode(
            discount=self.ap.algorithm.discount)
        # TODO: add agents observation rendering for debugging purposes (not the same as the environment rendering)

        # environment parameters
        # spaces stays None until set_environment_parameters is called
        self.spaces = None
        self.in_action_space = self.ap.algorithm.in_action_space

        # signals
        # the two lists must exist before the first register_signal call, which appends to them
        self.episode_signals = []
        self.step_signals = []
        self.loss = self.register_signal('Loss')
        self.curr_learning_rate = self.register_signal('Learning Rate')
        self.unclipped_grads = self.register_signal('Grads (unclipped)')
        # the reward signals are registered as per-step signals only
        self.reward = self.register_signal('Reward',
                                           dump_one_value_per_episode=False,
                                           dump_one_value_per_step=True)
        self.shaped_reward = self.register_signal(
            'Shaped Reward',
            dump_one_value_per_episode=False,
            dump_one_value_per_step=True)
        if isinstance(self.in_action_space, GoalsSpace):
            self.distance_from_goal = self.register_signal(
                'Distance From Goal', dump_one_value_per_step=True)

        # use seed
        # when a seed is given, it is applied to both random and np.random
        if self.ap.task_parameters.seed is not None:
            random.seed(self.ap.task_parameters.seed)
            np.random.seed(self.ap.task_parameters.seed)

    @property
    def parent(self):
        """
        Get the parent object of the agent
        :return: the parent last stored through the parent setter (may be None)
        """
        return self._parent

    @parent.setter
    def parent(self, val):
        """
        Attach the agent to a new parent.
        When the new parent is not None, the full name of the agent (on the agent and on its parameters) is
        rebuilt as "<parent name>/<agent name>". A None parent leaves the full name untouched.
        :param val: the new parent
        :return: None
        :raises ValueError: if the new parent is not None and has no 'name' attribute
        """
        self._parent = val
        if val is None:
            return
        if not hasattr(val, 'name'):
            raise ValueError("The parent of an agent must have a name")
        qualified_name = "{}/{}".format(val.name, self.name)
        self.full_name_id = self.ap.full_name_id = qualified_name

    def setup_logger(self):
        """
        Configure the output file names of the agent logger and, when in-episode signals dumping is enabled,
        of the episode logger as well. The file prefix is "<graph name>.<level name>.<agent full id>", where
        the '/' separators of the agent full name are replaced by '.'.
        :return: None
        """
        level_manager = self.parent_level_manager
        logger_prefix = "{graph_name}.{level_name}.{agent_full_id}".format(
            graph_name=level_manager.parent_graph_manager.name,
            level_name=level_manager.name,
            agent_full_id=self.full_name_id.replace('/', '.'))

        loggers_to_setup = [self.agent_logger]
        if self.ap.visualization.dump_in_episode_signals:
            loggers_to_setup.append(self.agent_episode_logger)

        # all the loggers share the same experiment path, prefix and task id
        for logger in loggers_to_setup:
            logger.set_logger_filenames(
                self.ap.task_parameters.experiment_path,
                logger_prefix=logger_prefix,
                add_timestamp=True,
                task_id=self.task_id)

    def set_session(self, sess) -> None:
        """
        Pass the deep learning framework session to the three filters of the agent and to each of its networks
        :param sess: the session object to hand over
        :return: None
        """
        for agent_filter in (self.input_filter, self.output_filter,
                             self.pre_network_filter):
            agent_filter.set_session(sess)
        for network in self.networks.values():
            network.set_session(sess)

    def register_signal(self,
                        signal_name: str,
                        dump_one_value_per_episode: bool = True,
                        dump_one_value_per_step: bool = False) -> Signal:
        """
        Create a new signal and add it to the per-episode and / or per-step signal lists, so that its values
        get written by the agent loggers
        :param signal_name: the name the signal is created with
        :param dump_one_value_per_episode: add the signal to the list of per-episode signals?
        :param dump_one_value_per_step: add the signal to the list of per-step signals?
        :return: the created signal
        """
        new_signal = Signal(signal_name)

        # a signal can belong to both lists, to one of them, or to none
        if dump_one_value_per_episode:
            self.episode_signals.append(new_signal)
        if dump_one_value_per_step:
            self.step_signals.append(new_signal)

        return new_signal

    def set_environment_parameters(self, spaces: SpacesDefinition):
        """
        Store a private copy of the environment spaces, adapt it to the filters of the agent, and then create
        the environment dependent modules by calling init_environment_dependent_modules
        :param spaces: the environment spaces definition (deep-copied, the given object is not modified)
        :return: None
        """
        self.spaces = copy.deepcopy(spaces)

        if self.ap.algorithm.use_accumulated_reward_as_measurement:
            if 'measurements' not in self.spaces.state.sub_spaces:
                # no measurements yet - the accumulated reward becomes the only measurement
                self.spaces.state['measurements'] = VectorObservationSpace(
                    1, measurements_names=['accumulated_reward'])
            else:
                # extend the existing measurements by one entry for the accumulated reward
                self.spaces.state['measurements'].shape += 1
                self.spaces.state['measurements'].measurements_names += [
                    'accumulated_reward'
                ]

        # each observation space passes through the input filter and then through the pre-network filter
        for observation_name in self.spaces.state.sub_spaces.keys():
            input_filtered_space = self.input_filter.get_filtered_observation_space(
                observation_name, self.spaces.state[observation_name])
            self.spaces.state[observation_name] = \
                self.pre_network_filter.get_filtered_observation_space(
                    observation_name, input_filtered_space)

        # the reward space goes through the same two filters, in the same order
        input_filtered_reward = self.input_filter.get_filtered_reward_space(
            self.spaces.reward)
        self.spaces.reward = self.pre_network_filter.get_filtered_reward_space(
            input_filtered_reward)

        self.spaces.action = self.output_filter.get_unfiltered_action_space(
            self.spaces.action)

        if isinstance(self.in_action_space, GoalsSpace):
            # TODO: what if the goal type is an embedding / embedding change?
            self.spaces.goal = self.in_action_space
            goal_target_space = self.spaces.state[self.spaces.goal.goal_name]
            self.spaces.goal.set_target_space(goal_target_space)

        self.init_environment_dependent_modules()

    def create_networks(self) -> Dict[str, NetworkWrapper]:
        """
        Create a NetworkWrapper for every entry of the agent's network_wrappers parameters, going over the
        network names in sorted order.
        self.spaces is passed to every wrapper, so the environment parameters should be set beforehand.
        :return: A dictionary mapping each network name to its wrapper
        """
        return {
            network_name: NetworkWrapper(
                name=network_name,
                agent_parameters=self.ap,
                has_target=self.ap.network_wrappers[network_name].create_target_network,
                has_global=self.has_global,
                spaces=self.spaces,
                replicated_device=self.replicated_device,
                worker_device=self.worker_device)
            for network_name in sorted(self.ap.network_wrappers.keys())
        }

    def init_environment_dependent_modules(self) -> None:
        """
        Create the modules that can only be built once the spaces are known: the exploration policy (which is
        given the action space) and the networks
        :return: None
        """
        # the exploration parameters need the action space before the policy is instantiated from them
        exploration_parameters = self.ap.exploration
        exploration_parameters.action_space = self.spaces.action
        self.exploration_policy = dynamic_import_and_instantiate_module_from_params(
            exploration_parameters)

        self.networks = self.create_networks()

    @property
    def phase(self) -> RunPhase:
        """
        Get the current run phase of the agent
        :return: the phase currently stored in self._phase
        """
        return self._phase

    @phase.setter
    def phase(self, val: RunPhase) -> None:
        """
        Change the phase of the run for the agent and for its exploration policy
        :param val: the new run phase (TRAIN, TEST, etc.)
        :return: None
        """
        # must run before _phase is overwritten: reset_evaluation_state compares the new phase with the
        # phase that is still current
        self.reset_evaluation_state(val)
        self._phase = val
        # NOTE(review): exploration_policy is None until init_environment_dependent_modules runs, so setting
        # the phase before that would fail here - confirm callers always set the environment parameters first
        self.exploration_policy.change_phase(val)

    def reset_evaluation_state(self, val: RunPhase) -> None:
        """
        Handle entering or leaving the evaluation (TEST) phase. Should be called with the upcoming phase while
        self.phase still holds the phase that is about to end.

        - when the upcoming phase is TEST, the evaluation accumulators are zeroed
        - otherwise, when the current phase is TEST, the averaged evaluation results are written to the agent
          logger under the next episode index. If no evaluation episode was completed, NaN is logged for every
          average instead of failing on a division by zero.
        :param val: the phase the agent is about to switch to
        :return: None
        """
        starting_evaluation = (val == RunPhase.TEST)
        ending_evaluation = (self.phase == RunPhase.TEST)
        should_log_title = self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high"

        if starting_evaluation:
            self.accumulated_rewards_across_evaluation_episodes = 0
            self.accumulated_shaped_rewards_across_evaluation_episodes = 0
            self.num_successes_across_evaluation_episodes = 0
            self.num_evaluation_episodes_completed = 0
            if should_log_title:
                screen.log_title("{}: Starting evaluation phase".format(
                    self.name))

        elif ending_evaluation:
            # we write to the next episode, because it could be that the current episode was already written
            # to disk and then we won't write it again
            self.agent_logger.set_current_time(self.current_episode + 1)

            episodes_completed = self.num_evaluation_episodes_completed
            if episodes_completed > 0:
                evaluation_reward = \
                    self.accumulated_rewards_across_evaluation_episodes / episodes_completed
                shaped_evaluation_reward = \
                    self.accumulated_shaped_rewards_across_evaluation_episodes / episodes_completed
                success_rate = self.num_successes_across_evaluation_episodes / episodes_completed
            else:
                # the evaluation phase ended before any episode was completed, so there is nothing to
                # average - log NaN, the same placeholder update_log writes for these signals outside
                # of the TEST phase
                evaluation_reward = shaped_evaluation_reward = success_rate = np.nan

            self.agent_logger.create_signal_value('Evaluation Reward',
                                                  evaluation_reward)
            self.agent_logger.create_signal_value('Shaped Evaluation Reward',
                                                  shaped_evaluation_reward)
            self.agent_logger.create_signal_value("Success Rate", success_rate)
            if should_log_title:
                screen.log_title(
                    "{}: Finished evaluation phase. Success rate = {}".format(
                        self.name, np.round(success_rate, 2)))

    def call_memory(self, func, args=()):
        """
        Invoke a memory function by name, hiding whether the memory is shared or local.
        Algorithms should go through this wrapper rather than touch the memory directly, so that they work
        with both kinds of memories.
        :param func: the name of the memory function to call
        :param args: the arguments for the function. For a local memory, anything which is not exactly a tuple
                     is treated as a single argument
        :return: the return value of the function
        """
        if self.shared_memory:
            # shared memories are reached through the scratchpad, which receives the arguments untouched
            return self.shared_memory_scratchpad.internal_call(
                self.memory_lookup_name, func, args)

        # an exact type check (not isinstance), so tuple subclasses are passed on as one argument
        call_args = args if type(args) is tuple else (args, )
        memory_function = getattr(self.memory, func)
        return memory_function(*call_args)

    def log_to_screen(self):
        """
        Print a one-line summary of the agent state (name, worker, episode, reward, exploration, steps and
        training iteration), prefixed by the value of the current phase
        :return: None
        """
        entries = [("Name", self.full_name_id)]
        # the worker entry appears only when the agent has a task id
        if self.task_id is not None:
            entries.append(("Worker", self.task_id))
        entries.append(("Episode", self.current_episode))
        entries.append(("Total reward",
                        np.round(self.total_reward_in_current_episode, 2)))
        entries.append(("Exploration",
                        np.round(self.exploration_policy.get_control_param(), 2)))
        entries.append(("Steps", self.total_steps_counter))
        entries.append(("Training iteration", self.training_iteration))

        screen.log_dict(OrderedDict(entries), prefix=self.phase.value)

    def update_step_in_episode_log(self):
        """
        Record the values of the current step in the episode logger, indexed by the step number inside the
        episode, and dump the episode logger to its csv file.
        :return: None
        """
        episode_logger = self.agent_episode_logger
        record = episode_logger.create_signal_value

        # the time axis of the in-episode log is the step counter of the current episode
        episode_logger.set_current_time(self.current_episode_steps_counter)

        record('Training Iter', self.training_iteration)
        record('In Heatup', int(self._phase == RunPhase.HEATUP))
        record('ER #Transitions', self.call_memory('num_transitions'))
        record('ER #Episodes', self.call_memory('length'))
        record('Total steps', self.total_steps_counter)
        record("Epsilon", self.exploration_policy.get_control_param())
        record("Shaped Accumulated Reward",
               self.total_shaped_reward_in_current_episode)
        # a default of 0 which does not override a value that was already written for this step
        record('Update Target Network', 0, overwrite=False)
        episode_logger.update_wall_clock_time(
            self.current_episode_steps_counter)

        # every per-step signal contributes its most recent value
        for step_signal in self.step_signals:
            record(step_signal.name, step_signal.get_last_value())

        # dump
        episode_logger.dump_output_csv()

    def update_log(self):
        """
        Record the values of the current episode in the agent logger, indexed by the episode number, and dump
        the logger to its csv file once every dump_signals_to_csv_every_x_episodes episodes.
        :return: None
        """
        logger = self.agent_logger
        record = logger.create_signal_value
        in_training = self._phase == RunPhase.TRAIN

        # log all the signals to file
        logger.set_current_time(self.current_episode)
        record('Training Iter', self.training_iteration)
        record('In Heatup', int(self._phase == RunPhase.HEATUP))
        record('ER #Transitions', self.call_memory('num_transitions'))
        record('ER #Episodes', self.call_memory('length'))
        record('Episode Length', self.current_episode_steps_counter)
        record('Total steps', self.total_steps_counter)
        record("Epsilon",
               np.mean(self.exploration_policy.get_control_param()))
        # the training rewards are meaningful only for episodes played in the TRAIN phase
        record("Shaped Training Reward",
               self.total_shaped_reward_in_current_episode if in_training else np.nan)
        record("Training Reward",
               self.total_reward_in_current_episode if in_training else np.nan)

        record('Update Target Network', 0, overwrite=False)
        logger.update_wall_clock_time(self.current_episode)

        if self._phase != RunPhase.TEST:
            # placeholders for the evaluation results, without overriding values that were already written
            for evaluation_signal_name in ('Evaluation Reward',
                                           'Shaped Evaluation Reward',
                                           'Success Rate'):
                record(evaluation_signal_name, np.nan, overwrite=False)

        # every per-episode signal contributes its mean, stdev, max and min
        for signal in self.episode_signals:
            for stat_name, stat_getter in (("Mean", signal.get_mean),
                                           ("Stdev", signal.get_stdev),
                                           ("Max", signal.get_max),
                                           ("Min", signal.get_min)):
                record("{}/{}".format(signal.name, stat_name), stat_getter())

        # dump
        is_dump_episode = \
            self.current_episode % self.ap.visualization.dump_signals_to_csv_every_x_episodes == 0
        if is_dump_episode and self.current_episode > 0:
            logger.dump_output_csv()

    def handle_episode_ended(self) -> None:
        """
        Close the current episode: mark its buffer as complete, advance the episode counter, store the episode
        in an episodic memory or accumulate the evaluation statistics (depending on the phase), and log.
        :return: None
        """
        self.current_episode_buffer.is_complete = True

        if self.phase == RunPhase.TEST:
            # evaluation episodes are counted as episodes only when the task is evaluation-only
            if self.ap.task_parameters.evaluate_only:
                self.current_episode += 1

            self.accumulated_rewards_across_evaluation_episodes += self.total_reward_in_current_episode
            self.accumulated_shaped_rewards_across_evaluation_episodes += self.total_shaped_reward_in_current_episode
            self.num_evaluation_episodes_completed += 1

            # a falsy threshold means that no episode is counted as a success
            success_threshold = self.spaces.reward.reward_success_threshold
            if success_threshold and self.total_reward_in_current_episode >= success_threshold:
                self.num_successes_across_evaluation_episodes += 1
        else:
            self.current_episode += 1
            # only episodic memories get the complete episode
            if isinstance(self.memory, EpisodicExperienceReplay):
                self.call_memory('store_episode', self.current_episode_buffer)

        if self.ap.visualization.dump_csv:
            self.update_log()

        if self.ap.is_a_highest_level_agent or self.ap.task_parameters.verbosity == "high":
            self.log_to_screen()

    def reset_internal_state(self):
        """
        Bring all the per-episode state back to its initial value: signals, reward accumulators, the current
        state, the episode buffer, the exploration policy, the filters and the internal memory of the networks
        :return: None
        """
        for signal_group in (self.episode_signals, self.step_signals):
            for signal in signal_group:
                signal.reset()

        self.agent_episode_logger.set_episode_idx(self.current_episode)

        self.total_shaped_reward_in_current_episode = 0
        self.total_reward_in_current_episode = 0
        self.current_episode_steps_counter = 0
        self.curr_state = {}
        self.episode_running_info = {}
        self.current_episode_buffer = Episode(
            discount=self.ap.algorithm.discount)

        # the exploration policy is skipped while it is still unset (or otherwise falsy)
        if self.exploration_policy:
            self.exploration_policy.reset()

        for agent_filter in (self.input_filter, self.output_filter,
                             self.pre_network_filter):
            agent_filter.reset()

        if isinstance(self.memory, EpisodicExperienceReplay):
            self.call_memory('verify_last_episode_is_closed')

        for network in self.networks.values():
            network.online_network.reset_internal_memory()

    def learn_from_batch(self, batch) -> Tuple[float, List, List]:
        """
        Train on a batch of transitions. This base implementation does not train anything and returns a zero
        loss with empty lists.
        :param batch: A list of transitions (unused here)
        :return: The total loss of the training, the loss per head and the unclipped gradients
        """
        total_loss, losses, unclipped_grads = 0, [], []
        return total_loss, losses, unclipped_grads

    def _should_update_online_weights_to_target(self):
        """
        Decide whether it is time to copy the online weights to the target network, according to
        num_steps_between_copying_online_weights_to_target. A positive answer also records the current count
        as the step of the last update.
        :return: boolean: True if the online weights should be copied to the target.
        :raises ValueError: if the parameter is neither TrainingSteps nor EnvironmentSteps
        """
        step_method = self.ap.algorithm.num_steps_between_copying_online_weights_to_target
        step_type = step_method.__class__

        # choose the counter which matches the units of the parameter
        if step_type == TrainingSteps:
            current_count = self.training_iteration
        elif step_type == EnvironmentSteps:
            current_count = self.total_steps_counter
        else:
            raise ValueError(
                "The num_steps_between_copying_online_weights_to_target parameter should be either "
                "EnvironmentSteps or TrainingSteps. Instead it is {}".format(
                    step_type))

        steps_since_last_update = current_count - self.last_target_network_update_step
        should_update = steps_since_last_update >= step_method.num_steps
        if should_update:
            self.last_target_network_update_step = current_count
        return should_update

    def _should_train(self, wait_for_full_episode=False):
        """
        Decide whether a training phase should start now, according to how much was played since the last
        training phase (num_consecutive_playing_steps). A positive answer also records the current count as
        the start of the last training phase.
        :param wait_for_full_episode: when counting in environment steps, additionally require that no step
                                      was taken yet in the current episode. Ignored when counting in episodes
        :return:  boolean: True if we should start a training phase
        :raises ValueError: if the parameter is neither EnvironmentEpisodes nor EnvironmentSteps
        """
        step_method = self.ap.algorithm.num_consecutive_playing_steps
        step_type = step_method.__class__

        if step_type == EnvironmentEpisodes:
            current_count = self.current_episode
            should_update = (current_count - self.last_training_phase_step) >= step_method.num_steps
        elif step_type == EnvironmentSteps:
            current_count = self.total_steps_counter
            should_update = (current_count - self.last_training_phase_step) >= step_method.num_steps
            if wait_for_full_episode:
                should_update = should_update and self.current_episode_steps_counter == 0
        else:
            raise ValueError(
                "The num_consecutive_playing_steps parameter should be either "
                "EnvironmentSteps or Episodes. Instead it is {}".format(
                    step_type))

        if should_update:
            self.last_training_phase_step = current_count
        return should_update

    def train(self):
        """
        Run a training phase, if one is due according to _should_train.
        A training phase consists of num_consecutive_training_steps iterations. Each iteration samples a batch
        from the memory, passes it through the pre-network filter, learns from it, logs the learning rate and
        updates the target networks when needed. post_training_commands is called once the phase is done.
        :return: The total training loss during the training iterations (0 when no training phase was due).
        """
        loss = 0
        if not self._should_train():
            return loss

        for _training_step in range(
                self.ap.algorithm.num_consecutive_training_steps):
            # TODO: this should be network dependent
            network_parameters = list(self.ap.network_wrappers.values())[0]

            # the iteration is counted even if the sampled batch turns out to be empty
            self.training_iteration += 1

            batch = self.call_memory('sample', network_parameters.batch_size)
            if self.pre_network_filter is not None:
                batch = self.pre_network_filter.filter(
                    batch, update_internal_state=False, deep_copy=False)

            # an empty batch means that there are not enough samples in the replay buffer -> skip the step
            if not len(batch) > 0:
                continue

            batch = Batch(batch)
            total_loss, _losses, unclipped_grads = self.learn_from_batch(batch)
            loss += total_loss
            self.unclipped_grads.add_sample(unclipped_grads)

            # TODO: the learning rate decay should be done through the network instead of here
            # with a decaying learning rate, the current value is read from the main online network
            if network_parameters.learning_rate_decay_rate != 0:
                main_network = self.networks['main']
                current_learning_rate = main_network.sess.run(
                    main_network.online_network.current_learning_rate)
            else:
                current_learning_rate = network_parameters.learning_rate
            self.curr_learning_rate.add_sample(current_learning_rate)

            some_network_has_target = any(
                [network.has_target for network in self.networks.values()])
            if some_network_has_target and self._should_update_online_weights_to_target():
                for network in self.networks.values():
                    network.update_target_network(
                        self.ap.algorithm.rate_for_copying_weights_to_target)
                self.agent_logger.create_signal_value(
                    'Update Target Network', 1)
            else:
                self.agent_logger.create_signal_value(
                    'Update Target Network', 0, overwrite=False)

            # note that the value added here is the loss accumulated so far in this training phase
            self.loss.add_sample(loss)

            if self.imitation:
                self.log_to_screen()

        # run additional commands after the training is done
        self.post_training_commands()

        return loss

    def choose_action(self, curr_state):
        """
        Choose the action to apply for the given state. This base implementation chooses nothing and returns
        None; act() stores whatever is returned here as the last action info.

        :param curr_state: the current state to act upon.
        :return: None in this base implementation
        """
        return None

    def prepare_batch_for_inference(self,
                                    states: Union[Dict[str, np.ndarray],
                                                  List[Dict[str, np.ndarray]]],
                                    network_name: str):
        """
        Stack the given state(s) into one array per network input: for every input embedder of the requested
        network which is present in the states, the values of all the states are stacked into a single array.
        :param states: a single state dictionary, or a list of state dictionaries
        :param network_name: the name of the network whose input embedders define which keys are taken
        :return: a dictionary mapping each taken key to the stacked array of its values
        """
        # convert to batch so we can run it through the network
        states = force_list(states)
        embedder_keys = self.ap.network_wrappers[
            network_name].input_embedders_parameters.keys()

        # a key that is missing from the state (looked up in the first state only) is skipped; there are cases
        # (e.g. ddpg) where some network inputs, such as the action, are added externally
        return {
            key: np.array([np.array(state[key]) for state in states])
            for key in embedder_keys
            if key in states[0].keys()
        }

    def act(self) -> ActionInfo:
        """
        Decide on the next action to apply to the environment, and update the step counters.
        The action is sampled from the action space during a heatup which does not use network decisions, and
        otherwise it is chosen by choose_action based on the (pre-network filtered) current state.
        :return: the chosen action info after passing through the output filter, or None for an agent which
                 does not play while training (num_consecutive_playing_steps of 0)
        """
        never_plays_in_training = \
            self.phase == RunPhase.TRAIN and self.ap.algorithm.num_consecutive_playing_steps.num_steps == 0
        if never_plays_in_training:
            # This agent never plays  while training (e.g. behavioral cloning)
            return None

        # count steps (only when training or if we are in the evaluation worker)
        if self.phase != RunPhase.TEST or self.ap.task_parameters.evaluate_only:
            self.total_steps_counter += 1
        self.current_episode_steps_counter += 1

        use_random_action = \
            self.phase == RunPhase.HEATUP and not self.ap.algorithm.heatup_using_network_decisions
        if use_random_action:
            self.last_action_info = self.spaces.action.sample_with_info()
        else:
            # informed action - the state is first passed through the pre-network filter, if there is one
            curr_state = self.curr_state
            if self.pre_network_filter is not None:
                curr_state = self.run_pre_network_filter_for_inference(
                    curr_state)
            self.last_action_info = self.choose_action(curr_state)

        return self.output_filter.filter(self.last_action_info)

    def run_pre_network_filter_for_inference(self, state: StateType):
        """
        Run a single state through the pre network filter.
        The state is wrapped in a placeholder environment response (zero reward, game not over), and the
        filtered state is read back from the first filtered response.
        :param state: the state to filter
        :return: the next_state of the first filtered response
        """
        placeholder_response = EnvResponse(next_state=state, reward=0, game_over=False)
        filtered_responses = self.pre_network_filter.filter(placeholder_response)
        return filtered_responses[0].next_state

    def get_state_embedding(self, state: dict) -> np.ndarray:
        """
        Fetch the state embedding that the online 'main' network produces for the given state.
        :param state: a state dict
        :return: the value predicted for the network's state_embedding output
        """
        # TODO: this won't work anymore
        # TODO: instead of the state embedding (which contains the goal) we should use the observation embedding
        online_main_network = self.networks['main'].online_network
        network_inputs = self.prepare_batch_for_inference(state, "main")
        return online_main_network.predict(network_inputs,
                                           outputs=online_main_network.state_embedding)

    def update_transition_before_adding_to_replay_buffer(
            self, transition: Transition) -> Transition:
        """
        Hook that receives a transition just before it is added to the replay buffer.
        This implementation hands the transition back untouched. Agents that want to tweak the reward,
        the termination signal, etc. can do so here.
        :param transition: the transition that is about to be stored
        :return: the transition to store
        """
        return transition

    def observe(self, env_response: EnvResponse) -> bool:
        """
        Process a response from the environment: filter it, turn it into a transition and keep that
        transition for later use.
        The response holds the new observation and measurements, the reward, a game over flag and any
        additional information.
        :param env_response: result of call from environment.step(action)
        :return: the game over flag. On the first step of an episode it is read from the raw response, and
                 on later steps from the transition that was formed out of the filtered response
        """
        filtered = self.input_filter.filter(env_response)[0]

        # inject agent collected statistics, if required
        if self.ap.algorithm.use_accumulated_reward_as_measurement:
            if 'measurements' in filtered.next_state:
                measurements = np.append(filtered.next_state['measurements'],
                                         self.total_shaped_reward_in_current_episode)
            else:
                measurements = np.array([self.total_shaped_reward_in_current_episode])
            filtered.next_state['measurements'] = measurements

        # On the first step of the episode there is no previous state and no reward, so there is no
        # transition to store yet. Only the current state is initialized.
        if self.current_episode_steps_counter == 0:
            self.curr_state = filtered.next_state
            return env_response.game_over

        transition = Transition(state=copy.copy(self.curr_state),
                                action=self.last_action_info.action,
                                reward=filtered.reward,
                                next_state=filtered.next_state,
                                game_over=filtered.game_over,
                                info=filtered.info)

        # the basic transition is formed, so the next state becomes the current state
        self.curr_state = filtered.next_state

        # let the specific agent adjust the transition
        transition = self.update_transition_before_adding_to_replay_buffer(transition)

        # merge the intrinsic reward in - either as a scale factor or as an additive bonus
        if self.ap.algorithm.scale_external_reward_by_intrinsic_reward_value:
            merged_reward = transition.reward * (1 + self.last_action_info.action_intrinsic_reward)
        else:
            merged_reward = transition.reward + self.last_action_info.action_intrinsic_reward
        transition.reward = merged_reward

        # accumulate both the shaped reward and the raw environment reward
        self.total_shaped_reward_in_current_episode += transition.reward
        self.total_reward_in_current_episode += env_response.reward
        self.shaped_reward.add_sample(transition.reward)
        self.reward.add_sample(env_response.reward)

        # the action info comes from the parent when the parent is a composite agent
        info_owner = self.parent if type(self.parent).__name__ == 'CompositeAgent' else self
        transition.add_info(info_owner.last_action_info.__dict__)

        # transitions are stored only while training or heating up
        if self.phase in (RunPhase.TRAIN, RunPhase.HEATUP):
            if isinstance(self.memory, EpisodicExperienceReplay):
                # episodic memories get the transitions through a local buffer, which is filled until the
                # episode is ended
                self.current_episode_buffer.insert(transition)
            else:
                # regular memories get the transitions directly
                self.call_memory('store', transition)

        if self.ap.visualization.dump_in_episode_signals:
            self.update_step_in_episode_log()

        return transition.game_over

    def post_training_commands(self):
        """
        Placeholder hook which does nothing in this implementation.
        NOTE(review): presumably meant to be overridden by agents that need to run something after
        training - confirm against the callers.
        :return: None
        """
        return None

    def get_predictions(self, states: List[Dict[str, np.ndarray]],
                        prediction_type: PredictionType):
        """
        Get a prediction from the agent with regard to the requested prediction_type.
        The states are handed as they are to the online 'main' network, which returns a mapping with one
        entry per component matching the prediction type. Exactly one match is required.
        :param states: the states to predict for, passed through to the network without any conversion
        :param prediction_type: the type of prediction to request from the network
        :return: the prediction of the single matching component
        :raises ValueError: if no component matches the prediction type, or if more than one does
        """
        predictions = self.networks['main'].online_network.predict_with_prediction_type(
            states=states, prediction_type=prediction_type)

        num_matches = len(predictions.keys())
        if num_matches == 0:
            # reported separately, since "more than one component" would be misleading for an empty result
            raise ValueError(
                "The network has no component matching the requested prediction_type {}."
                .format(prediction_type))
        if num_matches > 1:
            raise ValueError(
                "The network has more than one component {} matching the requested prediction_type {}."
                .format(list(predictions.keys()), prediction_type))
        return list(predictions.values())[0]

    def set_incoming_directive(self, action: ActionType) -> None:
        """
        Apply an incoming directive, according to the type of the agent's incoming action space.
        For a goals space the directive is kept as the current HRL goal. For an attention action space,
        its first two entries are set as the crop boundaries of the 'attention' observation filter and as
        the masking of the 'masking' action filter. Any other space type is ignored.
        :param action: the directive to apply
        :return: None
        """
        directive_space = self.in_action_space
        if isinstance(directive_space, GoalsSpace):
            self.current_hrl_goal = action
            return
        if not isinstance(directive_space, AttentionActionSpace):
            return

        attention_filter = self.input_filter.observation_filters['attention']
        attention_filter.crop_low = action[0]
        attention_filter.crop_high = action[1]
        self.output_filter.action_filters['masking'].set_masking(action[0], action[1])

    def save_checkpoint(self, checkpoint_id: int) -> None:
        """
        Hook for agents that need to store additional information when a checkpoint is saved.
        This implementation stores nothing.
        :param checkpoint_id: the id of the checkpoint
        :return: None
        """
        return None

    def sync(self) -> None:
        """
        Call sync() on every network held by the agent, which syncs the global network parameters to the
        local networks.
        :return: None
        """
        for network_name in self.networks:
            self.networks[network_name].sync()
Ejemplo n.º 4
0
class GraphManager(object):
    """
    A graph manager is responsible for creating and initializing a graph of agents, including all its internal
    components. It is then used in order to schedule the execution of operations on the graph, such as acting and
    training.
    """
    def __init__(
        self,
        name: str,
        schedule_params: ScheduleParameters,
        vis_params: VisualizationParameters = None):
        """
        Initialize the bookkeeping of the graph manager. The graph itself is not created here.
        :param name: the name of the graph
        :param schedule_params: the schedule to follow - its heatup_steps, evaluation_steps,
                                steps_between_evaluation_periods and improve_steps are copied onto the manager
        :param vis_params: the visualization parameters. When omitted (or None), a fresh
                           VisualizationParameters() is created for this manager
        """
        # A default built in the signature would be created once, at definition time, and shared by every
        # manager relying on it - so changes made through one manager would leak into the others.
        if vis_params is None:
            vis_params = VisualizationParameters()

        self.sess = None
        self.level_managers = []
        self.top_level_manager = None
        self.environments = []
        self.heatup_steps = schedule_params.heatup_steps
        self.evaluation_steps = schedule_params.evaluation_steps
        self.steps_between_evaluation_periods = schedule_params.steps_between_evaluation_periods
        self.improve_steps = schedule_params.improve_steps
        self.visualization_parameters = vis_params
        self.name = name
        self.task_parameters = None
        # assigning self.phase goes through the phase setter, which iterates over the (still empty) level
        # managers and environments lists defined above
        self._phase = self.phase = RunPhase.UNDEFINED
        self.preset_validation_params = PresetValidationParameters()
        self.reset_required = False

        # timers
        self.graph_creation_time = None  # stays None until create_graph() is called
        self.last_checkpoint_saving_time = time.time()

        # counters - one total steps counter per run phase
        self.total_steps_counters = {
            RunPhase.HEATUP: TotalStepsCounter(),
            RunPhase.TRAIN: TotalStepsCounter(),
            RunPhase.TEST: TotalStepsCounter()
        }
        self.checkpoint_id = 0

        self.checkpoint_saver = None
        self.graph_logger = Logger()
        self.data_store = None

    def create_graph(self, task_parameters: TaskParameters = None):
        """
        Create the graph modules, the session and the loggers.
        :param task_parameters: the parameters of the task. When omitted (or None), a fresh TaskParameters()
                                is created for this graph
        :return: the graph manager itself
        """
        # A default built in the signature would be a single object shared by all the graphs relying on it.
        # The task parameters are modified later on (e.g. save_checkpoint fills in checkpoint_save_dir), so
        # sharing them would leak state from one graph to another.
        if task_parameters is None:
            task_parameters = TaskParameters()

        self.graph_creation_time = time.time()
        self.task_parameters = task_parameters

        if isinstance(task_parameters, DistributedTaskParameters):
            screen.log_title(
                "Creating graph - name: {} task id: {} type: {}".format(
                    self.__class__.__name__, task_parameters.task_index,
                    task_parameters.job_type))
        else:
            screen.log_title("Creating graph - name: {}".format(
                self.__class__.__name__))

        # "hide" the gpu if necessary
        if task_parameters.use_cpu:
            set_cpu()

        # create a target server for the worker and a device
        if isinstance(task_parameters, DistributedTaskParameters):
            task_parameters.worker_target, task_parameters.device = \
                self.create_worker_or_parameters_server(task_parameters=task_parameters)

        # create the graph modules
        self.level_managers, self.environments = self._create_graph(
            task_parameters)

        # set self as the parent of all the level managers
        self.top_level_manager = self.level_managers[0]
        for level_manager in self.level_managers:
            level_manager.parent_graph_manager = self

        # create a session (it needs to be created after all the graph ops were created)
        self.sess = None
        self.create_session(task_parameters=task_parameters)

        # now that the level managers and environments exist, the phase setter propagates the phase to them
        self._phase = self.phase = RunPhase.UNDEFINED

        self.setup_logger()

        return self

    def _create_graph(
        self, task_parameters: TaskParameters
    ) -> Tuple[List[LevelManager], List[Environment]]:
        """
        Create all the graph modules and the graph scheduler
        :param task_parameters: the parameters of the task
        :return: the initialized level managers and environments
        """
        raise NotImplementedError("")

    @staticmethod
    def _create_worker_or_parameters_server_tf(
            task_parameters: DistributedTaskParameters):
        """
        Set up the TensorFlow side of a distributed task, according to its job type.
        :param task_parameters: the distributed task parameters - supplies the parameters server hosts, the
                                worker hosts, the job type, the task index and the use_cpu flag
        :return: for a "worker" job, the result of create_worker_server_and_device. A "ps" job starts the
                 parameters server, which per the original comment is a non-returning call
        :raises ValueError: if the job type is neither "ps" nor "worker"
        """
        import tensorflow as tf
        session_config = tf.ConfigProto()
        session_config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
        session_config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
        session_config.gpu_options.per_process_gpu_memory_fraction = 0.2
        session_config.intra_op_parallelism_threads = 1
        session_config.inter_op_parallelism_threads = 1

        from rl_coach.architectures.tensorflow_components.distributed_tf_utils import \
            create_and_start_parameters_server, \
            create_cluster_spec, create_worker_server_and_device

        cluster = create_cluster_spec(
            parameters_server=task_parameters.parameters_server_hosts,
            workers=task_parameters.worker_hosts)

        job_type = task_parameters.job_type
        if job_type == "worker":
            # create a worker and a device setter
            return create_worker_server_and_device(
                cluster_spec=cluster,
                task_index=task_parameters.task_index,
                use_cpu=task_parameters.use_cpu,
                config=session_config)
        if job_type != "ps":
            raise ValueError(
                "The job type should be either ps or worker and not {}".format(
                    task_parameters.job_type))

        # create and start the parameters server
        create_and_start_parameters_server(cluster_spec=cluster,
                                           config=session_config)

    @staticmethod
    def create_worker_or_parameters_server(
            task_parameters: DistributedTaskParameters):
        """
        Dispatch the creation of a worker or a parameters server according to the framework type.
        :param task_parameters: the distributed task parameters
        :return: whatever the TensorFlow implementation returns
        :raises NotImplementedError: for MXNet, where distributed training is not implemented
        :raises ValueError: for any other framework type
        """
        framework = task_parameters.framework_type
        if framework == Frameworks.tensorflow:
            return GraphManager._create_worker_or_parameters_server_tf(
                task_parameters)
        if framework == Frameworks.mxnet:
            raise NotImplementedError(
                'Distributed training not implemented for MXNet')
        raise ValueError('Invalid framework {}'.format(
            task_parameters.framework_type))

    def _create_session_tf(self, task_parameters: TaskParameters):
        """
        Create the TensorFlow session and hand it to all the modules of the graph.
        For distributed task parameters a monitored session is created. Otherwise a saver over the global
        variables and a regular session are created, and a checkpoint is restored if one was requested.
        In both cases the graph is saved afterwards if a checkpoint save dir is set.
        :param task_parameters: the parameters of the task
        :return: None
        """
        import tensorflow as tf
        session_config = tf.ConfigProto()
        session_config.allow_soft_placement = True  # allow placing ops on cpu if they are not fit for gpu
        session_config.gpu_options.allow_growth = True  # allow the gpu memory allocated for the worker to grow if needed
        session_config.intra_op_parallelism_threads = 1
        session_config.inter_op_parallelism_threads = 1

        is_distributed = isinstance(task_parameters, DistributedTaskParameters)
        if is_distributed:
            # the distributed tensorflow setting
            from rl_coach.architectures.tensorflow_components.distributed_tf_utils import create_monitored_session
            if getattr(self.task_parameters, 'checkpoint_restore_dir', None):
                # restoring works on a copy of the restore dir, placed under the experiment path
                checkpoint_dir = os.path.join(task_parameters.experiment_path,
                                              'checkpoint')
                if os.path.exists(checkpoint_dir):
                    remove_tree(checkpoint_dir)
                copy_tree(task_parameters.checkpoint_restore_dir,
                          checkpoint_dir)
            else:
                checkpoint_dir = task_parameters.checkpoint_save_dir

            self.sess = create_monitored_session(
                target=task_parameters.worker_target,
                task_index=task_parameters.task_index,
                checkpoint_dir=checkpoint_dir,
                checkpoint_save_secs=task_parameters.checkpoint_save_secs,
                config=session_config)
        else:
            # TODO: is it necessary to restore only the variables with '/online' in their name?
            self.variables_to_restore = tf.global_variables()
            self.checkpoint_saver = tf.train.Saver(self.variables_to_restore)

            # regular session
            self.sess = tf.Session(config=session_config)

        # set the session for all the modules
        self.set_session(self.sess)

        if not is_distributed:
            # restore from checkpoint if given
            self.restore_checkpoint()

        # the TF graph is static, and therefore is saved once - in the beginning of the experiment
        if getattr(self.task_parameters, 'checkpoint_save_dir', None):
            self.save_graph()

    def create_session(self, task_parameters: TaskParameters):
        """
        Create the deep learning framework session, according to the framework type.
        :param task_parameters: the parameters of the task
        :return: None
        :raises ValueError: if the framework type is neither tensorflow nor mxnet
        """
        framework = task_parameters.framework_type
        if framework == Frameworks.tensorflow:
            self._create_session_tf(task_parameters)
            return
        if framework == Frameworks.mxnet:
            # there is no session object for MXNet - this only initializes all the modules
            self.set_session(sess=None)
            # TODO add checkpoint loading
            return
        raise ValueError('Invalid framework {}'.format(
            task_parameters.framework_type))

    def save_graph(self) -> None:
        """
        Write the default TF graph, as a binary protobuf file named graphdef.pb, into the checkpoint save dir
        of the task parameters.
        :return: None
        """
        import tensorflow as tf

        default_graph = tf.get_default_graph()
        tf.train.write_graph(default_graph,
                             logdir=self.task_parameters.checkpoint_save_dir,
                             name='graphdef.pb',
                             as_text=False)

    def save_onnx_graph(self) -> None:
        """
        Save the graph as an ONNX graph.
        This requires the graph and the weights checkpoint to be stored in the experiment directory.
        It then freezes the graph (merging the graph and weights checkpoint), and converts it to ONNX.
        :return: None
        """
        # Walk over the online network of every agent in every level and collect the node names. Inputs
        # whose key starts with "output_" are not treated as input nodes.
        input_nodes = []
        output_nodes = []
        for level in self.level_managers:
            for agent in level.agents.values():
                for network in agent.networks.values():
                    input_nodes.extend(
                        tensor.name
                        for key, tensor in network.online_network.inputs.items()
                        if not key.startswith("output_"))
                    output_nodes.extend(
                        tensor.name for tensor in network.online_network.outputs)

        # TODO: make this framework agnostic
        from rl_coach.architectures.tensorflow_components.architecture import save_onnx_graph

        save_onnx_graph(input_nodes, output_nodes,
                        self.task_parameters.checkpoint_save_dir)

    def setup_logger(self) -> None:
        """
        Point the graph logger at the experiment path, optionally dump the documentation of the graph
        parameters, and set up the loggers of all the level managers.
        :return: None
        """
        task = self.task_parameters
        self.graph_logger.set_logger_filenames(
            task.experiment_path,
            logger_prefix="{graph_name}".format(graph_name=self.name),
            add_timestamp=True,
            task_id=task.task_index)

        # dump documentation - the string representation of the graph manager
        if self.visualization_parameters.dump_parameters_documentation:
            self.graph_logger.dump_documentation(str(self))

        for manager in self.level_managers:
            manager.setup_logger()

    @property
    def phase(self) -> RunPhase:
        """
        The current phase of the graph
        :return: the phase that was last set
        """
        return self._phase

    @phase.setter
    def phase(self, val: RunPhase):
        """
        Set the phase of the graph, and push it down to every level manager and every environment
        :param val: the new phase
        :return: None
        """
        self._phase = val
        # the level managers are updated first, and the environments after them
        for members in (self.level_managers, self.environments):
            for member in members:
                member.phase = val

    @property
    def current_step_counter(self) -> TotalStepsCounter:
        """
        The total steps counter which belongs to the current phase
        :return: the entry of total_steps_counters for the current phase
        """
        counters = self.total_steps_counters
        return counters[self.phase]

    @contextlib.contextmanager
    def phase_context(self, phase):
        """
        Context manager which runs the enclosed block under the given phase.
        The phase that was set when entering the block is put back when leaving it - also when the block
        raises, so that a failure does not leave the graph stuck in the temporary phase.
        :param phase: the phase to use inside the block
        :return: a context manager, which yields nothing
        """
        old_phase = self.phase
        self.phase = phase
        try:
            yield
        finally:
            self.phase = old_phase

    def set_session(self, sess) -> None:
        """
        Hand the deep learning framework session to every level manager in the graph
        :param sess: the session to set
        :return: None
        """
        for manager in self.level_managers:
            manager.set_session(sess)

    def heatup(self, steps: PlayingStepsType) -> None:
        """
        Act under the HEATUP phase for at least the given number of steps, one full episode at a time.
        Nothing is done when the number of steps is not positive.
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        if not (steps.num_steps > 0):
            return

        with self.phase_context(RunPhase.HEATUP):
            screen.log_title("{}: Starting heatup".format(self.name))

            # start from a clean slate in all the levels
            self.reset_internal_state(force_environment_reset=True)

            # episodes are never interrupted, so the target may be overshot
            target = self.current_step_counter + steps
            while self.current_step_counter < target:
                self.act(EnvironmentEpisodes(1))

    def handle_episode_ended(self) -> None:
        """
        Count the episode which has just ended in the counter of the current phase, and notify all the
        environments that it ended
        :return: None
        """
        self.current_step_counter[EnvironmentEpisodes] += 1

        for environment in self.environments:
            environment.handle_episode_ended()

    def train(self) -> None:
        """
        Perform a single training iteration: under the TRAIN phase, count one training step and call train()
        on every level manager in the hierarchy
        :return: None
        """
        self.verify_graph_was_created()

        with self.phase_context(RunPhase.TRAIN):
            self.current_step_counter[TrainingSteps] += 1
            for manager in self.level_managers:
                manager.train()

    def reset_internal_state(self, force_environment_reset=False) -> None:
        """
        Reset the episode in all the environments, and then in all the level managers
        :param force_environment_reset: force the environment to reset the episode even if it has some conditions that
                                        tell it not to. for example, if ale life is lost, gym will tell the agent that
                                        the episode is finished but won't actually reset the episode if there are more
                                        lives available
        :return: None
        """
        self.verify_graph_was_created()

        # the pending reset request, if there was one, is served by this call
        self.reset_required = False

        for environment in self.environments:
            environment.reset_internal_state(force_environment_reset)
        for manager in self.level_managers:
            manager.reset_internal_state()

    def act(self,
            steps: PlayingStepsType,
            wait_for_full_episodes=False) -> None:
        """
        Step the top level manager until the counter of the current phase advanced by the given steps
        :param wait_for_full_episodes: if set, act for at least `steps`, but make sure that the last episode is complete
        :param steps: the number of steps as a tuple of steps time and steps count
        """
        self.verify_graph_was_created()

        # a rollout worker loads from the data store before playing
        has_memory_backend = hasattr(self, 'data_store_params') and hasattr(
            self.agent_params.memory, 'memory_backend_params')
        if has_memory_backend \
                and self.agent_params.memory.memory_backend_params.run_type == str(RunType.ROLLOUT_WORKER):
            self.get_data_store(self.data_store_params).load_from_store()

        target = self.current_step_counter + steps
        result = None
        while self.current_step_counter < target or (
                wait_for_full_episodes and result is not None
                and not result.game_over):
            # the previous episode was terminated, so a new one has to be started
            if self.reset_required:
                self.reset_internal_state()

            env_steps_before = self.environments[0].total_steps_counter
            result = self.top_level_manager.step(None)
            env_steps_after = self.environments[0].total_steps_counter

            # Only the steps which the environment counted during this call are added, so that environment
            # initialization steps (like in Atari) will not be counted. At least one step is always added, so
            # that the loop ends eventually even if no actions are taken (as can happen in the training phase).
            self.current_step_counter[EnvironmentSteps] += max(
                1, env_steps_after - env_steps_before)

            if result.game_over:
                self.handle_episode_ended()
                self.reset_required = True

    def train_and_act(self, steps: StepMethod) -> None:
        """
        Under the TRAIN phase, repeatedly act, train and possibly save a checkpoint, until the counter of the
        phase advanced by the given steps. Nothing is done when the number of steps is not positive.
        :param steps: the number of steps as a tuple of steps time and steps count
        :return: None
        """
        self.verify_graph_was_created()

        if not (steps.num_steps > 0):
            return

        with self.phase_context(RunPhase.TRAIN):
            self.reset_internal_state(force_environment_reset=True)

            target = self.current_step_counter + steps
            while self.current_step_counter < target:
                # This is just a high-level controller - the actual steps being done on the environment
                # are decided by the agents themselves.
                self.act(EnvironmentSteps(1))
                self.train()
                self.occasionally_save_checkpoint()

    def sync(self) -> None:
        """
        Call sync() on every level manager, in order to sync the global network parameters to the graph
        :return: None
        """
        for manager in self.level_managers:
            manager.sync()

    def evaluate(self,
                 steps: PlayingStepsType,
                 keep_networks_in_sync: bool = False) -> bool:
        """
        Act under the TEST phase for at least the given number of steps, and check whether the run should stop
        :param steps: the number of steps as a tuple of steps time and steps count
        :param keep_networks_in_sync: not consulted by this implementation - the networks are synced before
                                      the first episode and after each episode regardless of its value
        :return: bool, True if the target reward and target success has been reached
        """
        self.verify_graph_was_created()

        if steps.num_steps > 0:
            with self.phase_context(RunPhase.TEST):
                # start from a clean slate in all the levels, with synced networks
                self.reset_internal_state(force_environment_reset=True)
                self.sync()

                # episodes are never interrupted, so the target may be overshot
                target = self.current_step_counter + steps
                while self.current_step_counter < target:
                    self.act(EnvironmentEpisodes(1))
                    self.sync()

        if not self.should_stop():
            return False

        # leave an (empty) marker file in the checkpoint save dir, if there is one
        save_dir = self.task_parameters.checkpoint_save_dir
        if save_dir:
            marker_path = os.path.join(save_dir, SyncFiles.FINISHED.value)
            with open(marker_path, 'w'):
                pass
        if hasattr(self, 'data_store_params'):
            self.get_data_store(self.data_store_params).save_to_store()

        screen.success("Reached required success rate. Exiting.")
        return True

    def improve(self):
        """
        The main loop of the run.
        Defined in the following steps:
        1. Sync and heatup
        2. Repeat, until the TRAIN counter advanced by improve_steps or an evaluation asks to stop:
            2.1. Train and act for steps_between_evaluation_periods (acting, training and possibly
                 saving a checkpoint)
            2.2. Evaluate for evaluation_steps
        :return: None
        """
        self.verify_graph_was_created()

        # initialize the network parameters from the global network
        self.sync()

        self.heatup(self.heatup_steps)

        task_index = self.task_parameters.task_index
        if task_index is None:
            title = "Starting to improve {}".format(self.name)
        else:
            title = "Starting to improve {} task index {}".format(
                self.name, task_index)
        screen.log_title(title)

        target = self.total_steps_counters[RunPhase.TRAIN] + self.improve_steps
        while self.total_steps_counters[RunPhase.TRAIN] < target:
            self.train_and_act(self.steps_between_evaluation_periods)
            finished = self.evaluate(self.evaluation_steps)
            if finished:
                break

    def _restore_checkpoint_tf(self, checkpoint_dir: str):
        """
        Load the variables stored in a TensorFlow checkpoint dir into the variables to restore.
        Stored names have 'global/' replaced by 'online/' before they are matched against the graph variables.
        :param checkpoint_dir: the directory holding the checkpoint
        :return: None
        """
        import tensorflow as tf
        checkpoint_state = tf.train.get_checkpoint_state(checkpoint_dir)
        screen.log_title("Loading checkpoint: {}".format(
            checkpoint_state.model_checkpoint_path))

        # stored values, keyed by the name under which the graph knows the variable
        restored_values = {
            stored_name.replace('global/', 'online/'):
                tf.contrib.framework.load_variable(checkpoint_dir, stored_name)
            for stored_name, _ in tf.contrib.framework.list_variables(checkpoint_dir)
        }

        for variable in self.variables_to_restore:
            # the part of the variable name after the ':' is not part of the stored name
            graph_name = variable.name.split(':')[0]
            self.sess.run(variable.assign(restored_values[graph_name]))

    def restore_checkpoint(self):
        """
        Restore a checkpoint, if the task parameters hold a checkpoint restore dir.
        :return: None
        :raises ValueError: if a restore dir is given and the framework type is neither tensorflow nor mxnet
        """
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        restore_dir = getattr(self.task_parameters, 'checkpoint_restore_dir', None)
        if not restore_dir:
            return

        framework = self.task_parameters.framework_type
        if framework == Frameworks.tensorflow:
            self._restore_checkpoint_tf(restore_dir)
        elif framework == Frameworks.mxnet:
            # TODO implement checkpoint restore
            pass
        else:
            raise ValueError('Invalid framework {}'.format(
                self.task_parameters.framework_type))

    def occasionally_save_checkpoint(self):
        """
        Save a checkpoint if periodic saving is enabled, enough time has passed since the last save, and this
        is the chief process (task index 0 when distributed, or None for a single worker).
        :return: None
        """
        params = self.task_parameters
        interval = params.checkpoint_save_secs
        if not interval:
            # periodic saving is disabled
            return

        elapsed = time.time() - self.last_checkpoint_saving_time
        if elapsed >= interval:
            # only the chief process saves checkpoints
            task_index = params.task_index
            if task_index == 0 or task_index is None:
                self.save_checkpoint()

    def save_checkpoint(self):
        """
        Save a checkpoint and let every level manager save its own additional information.
        A missing checkpoint save dir defaults to 'checkpoint' under the experiment path, and is written back
        to the task parameters.
        :return: None
        """
        params = self.task_parameters
        if params.checkpoint_save_dir is None:
            params.checkpoint_save_dir = os.path.join(params.experiment_path, 'checkpoint')

        checkpoint_name = "{}_Step-{}.ckpt".format(
            self.checkpoint_id,
            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
        checkpoint_path = os.path.join(params.checkpoint_save_dir, checkpoint_name)

        if isinstance(params, DistributedTaskParameters):
            # nothing is saved through the saver here - the path is only reported
            saved_checkpoint_path = checkpoint_path
        elif self.checkpoint_saver is not None:
            saved_checkpoint_path = self.checkpoint_saver.save(self.sess, checkpoint_path)
        else:
            saved_checkpoint_path = "<Not Saved>"

        # this is required in order for agents to save additional information like a DND for example
        for manager in self.level_managers:
            manager.save_checkpoint(self.checkpoint_id)

        # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
        if params.export_onnx_graph:
            self.save_onnx_graph()

        screen.log_dict(OrderedDict([("Saving in path", saved_checkpoint_path)]),
                        prefix="Checkpoint")

        self.checkpoint_id += 1
        self.last_checkpoint_saving_time = time.time()

        if hasattr(self, 'data_store_params'):
            self.get_data_store(self.data_store_params).save_to_store()

    def verify_graph_was_created(self):
        """
        Makes sure the graph exists. A graph creation time of None is taken to mean the graph was not created
        yet, in which case create_graph() is called.
        :return: None
        """
        if self.graph_creation_time is not None:
            # the graph was already created - nothing to do
            return

        self.create_graph()

    def __str__(self):
        """
        Builds a textual dump of all the instance attributes. Each attribute is written as its name followed by
        its value on the next line. List and dict values (including OrderedDict, which is a dict subclass) are
        expanded into one "key: value" line per item, using iterable_to_items to enumerate them.
        :return: the string representation of this object
        """
        sections = []
        for attr_name, attr_value in self.__dict__.items():
            if isinstance(attr_value, (list, dict)):
                rendered = "".join(
                    "{}: {}\n".format(item_key, item_value)
                    for item_key, item_value in iterable_to_items(attr_value))
            else:
                rendered = attr_value
            sections.append("{}: \n{}\n".format(attr_name, rendered))

        return "".join(sections)

    def should_train(self) -> bool:
        """
        Asks every level manager whether it should train.
        :return: True if at least one level manager reports that it should train, False otherwise
        """
        any_wants_training = False
        # every manager is queried, even after one has already answered positively
        for level_manager in self.level_managers:
            if level_manager.should_train():
                any_wants_training = True
        return any_wants_training

    # TODO-remove - this is a temporary flow, used by the trainer worker, duplicated from observe() - need to create
    #               an external trainer flow reusing the existing flow and methods [e.g. observe(), step(), act()]
    def emulate_act_on_trainer(self, steps: PlayingStepsType,
                               transition: Transition) -> None:
        """
        Emulates acting on the training worker, using a transition that was obtained from the rollout worker,
        in case of distributed training.
        The top level manager is stepped with the given transition until the current step counter has advanced
        by the requested number of steps.
        :param steps: the number of steps as a tuple of steps time and steps count
        :param transition: the transition to pass to the top level manager on each emulated step
        :return: None
        """
        self.verify_graph_was_created()

        # keep playing until the step counter reaches this target
        target_count = self.current_step_counter + steps
        while self.current_step_counter < target_count:
            # a previous episode was terminated - start from a clean internal state
            if self.reset_required:
                self.reset_internal_state()

            env_steps_before = self.environments[0].total_steps_counter
            self.top_level_manager.emulate_step_on_trainer(transition)
            env_steps_after = self.environments[0].total_steps_counter

            # only the diff between the environment total steps before and after stepping is counted, such that
            # environment initialization steps (like in Atari) will not be counted.
            # at least one step is always added, so that even if no steps were made (in case no actions are taken
            # in the training phase), the loop will end eventually.
            env_steps_taken = env_steps_after - env_steps_before
            self.current_step_counter[EnvironmentSteps] += max(1, env_steps_taken)

            if not transition.game_over:
                continue

            self.handle_episode_ended()
            self.reset_required = True

    def fetch_from_worker(self, num_consecutive_playing_steps=None):
        """
        Fetches transitions from the memory backend and emulates a single environment step on the trainer for
        each one of them. Does nothing if no memory backend was set up on this object.
        :param num_consecutive_playing_steps: passed as is to the memory backend fetch call
        :return: None
        """
        if not hasattr(self, 'memory_backend'):
            return

        fetched_transitions = self.memory_backend.fetch(num_consecutive_playing_steps)
        for fetched_transition in fetched_transitions:
            self.emulate_act_on_trainer(EnvironmentSteps(1), fetched_transition)

    def setup_memory_backend(self) -> None:
        """
        Creates the memory backend from the memory backend parameters found on the agent's memory parameters.
        If the memory parameters do not define memory backend parameters, no memory backend is set.
        :return: None
        """
        memory_params = self.agent_params.memory
        if not hasattr(memory_params, 'memory_backend_params'):
            return

        self.memory_backend = get_memory_backend(memory_params.memory_backend_params)

    def should_stop(self) -> bool:
        """
        Asks every level manager whether it should stop.
        :return: True if all the level managers report that they should stop (or if there are no level managers),
                 False otherwise
        """
        all_want_to_stop = True
        # every manager is queried, even after one has already answered negatively
        for level_manager in self.level_managers:
            if not level_manager.should_stop():
                all_want_to_stop = False
        return all_want_to_stop

    def get_data_store(self, param):
        """
        Returns the data store held by this object if there is one, and otherwise creates one from the given
        parameters. A data store created here is returned to the caller but is not stored on this object.
        :param param: the parameters passed to data_store_creator when no data store is held
        :return: the data store
        """
        existing_store = self.data_store
        if not existing_store:
            return data_store_creator(param)

        return existing_store