Example #1
0
    def restore_checkpoint(self):
        self.verify_graph_was_created()

        # TODO: find better way to load checkpoints that were saved with a global network into the online network
        if self.task_parameters.checkpoint_restore_path:
            restored_checkpoint_paths = []
            for agent_params in self.agents_params:
                if len(self.agents_params) == 1:
                    agent_checkpoint_restore_path = self.task_parameters.checkpoint_restore_path
                else:
                    agent_checkpoint_restore_path = os.path.join(
                        self.task_parameters.checkpoint_restore_path,
                        agent_params.name)
                if os.path.isdir(agent_checkpoint_restore_path):
                    # a checkpoint dir
                    if self.task_parameters.framework_type == Frameworks.tensorflow and\
                            'checkpoint' in os.listdir(agent_checkpoint_restore_path):
                        # TODO-fixme checkpointing
                        # MonitoredTrainingSession manages save/restore checkpoints autonomously. In doing so,
                        # it creates its own names for the saved checkpoints, which do not match the "{}_Step-{}.ckpt"
                        # filename pattern. The names used are maintained in a CheckpointState protobuf file named
                        # 'checkpoint'. Using Coach's '.coach_checkpoint' protobuf file results in an error when
                        # trying to restore the model, as the checkpoint names defined do not match the actual
                        # checkpoint names.
                        raise NotImplementedError(
                            'Checkpointing not implemented for TF monitored training session'
                        )
                    else:
                        checkpoint = get_checkpoint_state(
                            agent_checkpoint_restore_path,
                            all_checkpoints=True)

                    if checkpoint is None:
                        raise ValueError(
                            "No checkpoint to restore in: {}".format(
                                agent_checkpoint_restore_path))
                    model_checkpoint_path = checkpoint.model_checkpoint_path
                    checkpoint_restore_dir = self.task_parameters.checkpoint_restore_path
                    restored_checkpoint_paths.append(model_checkpoint_path)

                    # Set the last checkpoint ID - only in the case of the path being a dir
                    chkpt_state_reader = CheckpointStateReader(
                        agent_checkpoint_restore_path,
                        checkpoint_state_optional=False)
                    self.checkpoint_id = chkpt_state_reader.get_latest().num + 1
                else:
                    # a checkpoint file
                    if self.task_parameters.framework_type == Frameworks.tensorflow:
                        model_checkpoint_path = agent_checkpoint_restore_path
                        checkpoint_restore_dir = os.path.dirname(
                            model_checkpoint_path)
                        restored_checkpoint_paths.append(model_checkpoint_path)
                    else:
                        raise ValueError(
                            "Currently, restoring a checkpoint using the --checkpoint_restore_file argument"
                            " is only supported with TensorFlow.")

                try:
                    self.checkpoint_saver[agent_params.name].restore(
                        self.sess[agent_params.name], model_checkpoint_path)
                except Exception as ex:
                    raise ValueError(
                        "Failed to restore {}'s checkpoint: {}".format(
                            agent_params.name, ex))

                # Prune old checkpoints - only possible when restoring from a dir, since
                # `checkpoint` (the CheckpointState) is loaded only in that branch
                if os.path.isdir(agent_checkpoint_restore_path):
                    all_checkpoints = sorted(
                        set(c.name for c in checkpoint.all_checkpoints))  # remove duplicates :-(
                    if self.num_checkpoints_to_keep < len(all_checkpoints):
                        checkpoint_to_delete = all_checkpoints[-self.num_checkpoints_to_keep - 1]
                        agent_checkpoint_to_delete = os.path.join(
                            agent_checkpoint_restore_path, checkpoint_to_delete)
                        for file in glob.glob("{}*".format(agent_checkpoint_to_delete)):
                            os.remove(file)

            for manager in self.level_managers:
                manager.restore_checkpoint(checkpoint_restore_dir)
            for manager in self.level_managers:
                manager.post_training_commands()

            screen.log_dict(
                OrderedDict([("Restoring from path", restored_checkpoint_path)
                             for restored_checkpoint_path in restored_checkpoint_paths]),
                prefix="Checkpoint")
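A note on the pruning step near the end of Example #1: it keeps only the newest num_checkpoints_to_keep checkpoints by globbing every file that shares the deleted checkpoint's prefix. The following is a minimal, standalone sketch of the same idea using only the standard library; the "{}_Step-{}.ckpt" naming pattern is taken from the comment in the example, and prune_old_checkpoints is an illustrative name, not part of the original code.

import glob
import os
import re


def prune_old_checkpoints(checkpoint_dir, num_checkpoints_to_keep):
    """Delete the files of all but the newest num_checkpoints_to_keep checkpoints.

    Assumes checkpoint files follow the '{id}_Step-{step}.ckpt*' naming pattern
    mentioned in the comments above (e.g. '3_Step-12000.ckpt.index').
    """
    # Map each checkpoint prefix (e.g. '3_Step-12000.ckpt') to its numeric id
    pattern = re.compile(r'^(\d+)_Step-\d+\.ckpt')
    prefixes = {}
    for filename in os.listdir(checkpoint_dir):
        match = pattern.match(filename)
        if match:
            prefixes[match.group(0)] = int(match.group(1))

    # Keep the newest N ids; remove every file belonging to an older checkpoint
    sorted_prefixes = sorted(prefixes, key=prefixes.get)
    num_to_delete = max(len(sorted_prefixes) - num_checkpoints_to_keep, 0)
    for prefix in sorted_prefixes[:num_to_delete]:
        for file_path in glob.glob(os.path.join(checkpoint_dir, prefix + '*')):
            os.remove(file_path)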
Example #2
0
def rollout_worker(graph_manager, num_workers, rollout_idx, task_parameters,
                   simtrace_video_s3_writers):
    """
    Wait for the first checkpoint, then perform rollouts using the model.
    """
    if not graph_manager.data_store:
        raise AttributeError("None type for data_store object")

    data_store = graph_manager.data_store

    # TODO: change 'agent' to the specific agent name for the multi-agent case
    checkpoint_dir = os.path.join(task_parameters.checkpoint_restore_path,
                                  "agent")
    graph_manager.data_store.wait_for_checkpoints()
    graph_manager.data_store.wait_for_trainer_ready()
    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics_dr')
    rospy.wait_for_service('/gazebo/unpause_physics_dr')
    rospy.wait_for_service('/racecar/save_mp4/subscribe_to_save_mp4')
    rospy.wait_for_service('/racecar/save_mp4/unsubscribe_from_save_mp4')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics_dr', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics_dr', Empty)
    subscribe_to_save_mp4 = ServiceProxyWrapper(
        '/racecar/save_mp4/subscribe_to_save_mp4', Empty)
    unsubscribe_from_save_mp4 = ServiceProxyWrapper(
        '/racecar/save_mp4/unsubscribe_from_save_mp4', Empty)
    graph_manager.create_graph(task_parameters=task_parameters,
                               stop_physics=pause_physics,
                               start_physics=unpause_physics,
                               empty_service_call=EmptyRequest)

    chkpt_state_reader = CheckpointStateReader(checkpoint_dir,
                                               checkpoint_state_optional=False)
    last_checkpoint = chkpt_state_reader.get_latest().num

    # this worker should play a fraction of the total playing steps per rollout
    episode_steps_per_rollout = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    act_steps = int(episode_steps_per_rollout / num_workers)
    if rollout_idx < episode_steps_per_rollout % num_workers:
        act_steps += 1
    act_steps = EnvironmentEpisodes(act_steps)

    configure_environment_randomizer()

    for _ in range(
        (graph_manager.improve_steps / act_steps.num_steps).num_steps):
        # Collect profiler information only if IS_PROFILER_ON is true
        with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET,
                            s3_prefix=PROFILER_S3_PREFIX,
                            output_local_path=ROLLOUT_WORKER_PROFILER_PATH,
                            enable_profiling=IS_PROFILER_ON):
            graph_manager.phase = RunPhase.TRAIN
            exit_if_trainer_done(checkpoint_dir, simtrace_video_s3_writers,
                                 rollout_idx)
            unpause_physics(EmptyRequest())
            graph_manager.reset_internal_state(True)
            graph_manager.act(
                act_steps,
                wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
            graph_manager.reset_internal_state(True)
            time.sleep(1)
            pause_physics(EmptyRequest())

            graph_manager.phase = RunPhase.UNDEFINED
            new_checkpoint = -1
            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.SYNC:
                unpause_physics(EmptyRequest())
                is_save_mp4_enabled = rospy.get_param(
                    'MP4_S3_BUCKET', None) and rollout_idx == 0
                if is_save_mp4_enabled:
                    subscribe_to_save_mp4(EmptyRequest())
                if rollout_idx == 0:
                    for _ in range(int(rospy.get_param('MIN_EVAL_TRIALS',
                                                       '5'))):
                        graph_manager.evaluate(EnvironmentSteps(1))

                while new_checkpoint < last_checkpoint + 1:
                    exit_if_trainer_done(checkpoint_dir,
                                         simtrace_video_s3_writers,
                                         rollout_idx)
                    if rollout_idx == 0:
                        print(
                            "Additional evaluation. New Checkpoint: {}, Last Checkpoint: {}"
                            .format(new_checkpoint, last_checkpoint))
                        graph_manager.evaluate(EnvironmentSteps(1))
                    else:
                        time.sleep(5)
                    new_checkpoint = data_store.get_coach_checkpoint_number(
                        'agent')
                if is_save_mp4_enabled:
                    unsubscribe_from_save_mp4(EmptyRequest())
                logger.info(
                    "Completed iteration tasks. Writing results to S3.")
                # upload simtrace and mp4 into s3 bucket
                for s3_writer in simtrace_video_s3_writers:
                    s3_writer.persist(utils.get_s3_kms_extra_args())
                pause_physics(EmptyRequest())
                logger.info(
                    "Preparing to load checkpoint {}".format(last_checkpoint +
                                                             1))
                data_store.load_from_store(
                    expected_checkpoint_number=last_checkpoint + 1)
                graph_manager.restore_checkpoint()

            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.ASYNC:
                if new_checkpoint > last_checkpoint:
                    graph_manager.restore_checkpoint()

            last_checkpoint = new_checkpoint

    logger.info("Exited main loop. Done.")
Example #3
0
def rollout_worker(graph_manager, num_workers, rollout_idx, task_parameters,
                   s3_writer):
    """
    Wait for the first checkpoint, then perform rollouts using the model.
    """
    if not graph_manager.data_store:
        raise AttributeError("None type for data_store object")

    data_store = graph_manager.data_store

    checkpoint_dir = task_parameters.checkpoint_restore_path
    wait_for_checkpoint(checkpoint_dir, data_store)
    wait_for_trainer_ready(checkpoint_dir, data_store)
    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics')
    rospy.wait_for_service('/gazebo/unpause_physics')
    rospy.wait_for_service('/racecar/save_mp4/subscribe_to_save_mp4')
    rospy.wait_for_service('/racecar/save_mp4/unsubscribe_from_save_mp4')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics', Empty)
    subscribe_to_save_mp4 = ServiceProxyWrapper(
        '/racecar/save_mp4/subscribe_to_save_mp4', Empty)
    unsubscribe_from_save_mp4 = ServiceProxyWrapper(
        '/racecar/save_mp4/unsubscribe_from_save_mp4', Empty)
    graph_manager.create_graph(task_parameters=task_parameters,
                               stop_physics=pause_physics,
                               start_physics=unpause_physics,
                               empty_service_call=EmptyRequest)

    chkpt_state_reader = CheckpointStateReader(checkpoint_dir,
                                               checkpoint_state_optional=False)
    last_checkpoint = chkpt_state_reader.get_latest().num

    # this worker should play a fraction of the total playing steps per rollout
    episode_steps_per_rollout = graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps
    act_steps = int(episode_steps_per_rollout / num_workers)
    if rollout_idx < episode_steps_per_rollout % num_workers:
        act_steps += 1
    act_steps = EnvironmentEpisodes(act_steps)

    configure_environment_randomizer()

    for _ in range(
        (graph_manager.improve_steps / act_steps.num_steps).num_steps):
        # Collect profiler information only if IS_PROFILER_ON is true
        with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET,
                            s3_prefix=PROFILER_S3_PREFIX,
                            output_local_path=ROLLOUT_WORKER_PROFILER_PATH,
                            enable_profiling=IS_PROFILER_ON):
            graph_manager.phase = RunPhase.TRAIN
            exit_if_trainer_done(checkpoint_dir, s3_writer, rollout_idx)
            unpause_physics(EmptyRequest())
            graph_manager.reset_internal_state(True)
            graph_manager.act(
                act_steps,
                wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
            graph_manager.reset_internal_state(True)
            time.sleep(1)
            pause_physics(EmptyRequest())

            graph_manager.phase = RunPhase.UNDEFINED
            new_checkpoint = -1
            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.SYNC:
                unpause_physics(EmptyRequest())
                is_save_mp4_enabled = rospy.get_param(
                    'MP4_S3_BUCKET', None) and rollout_idx == 0
                if is_save_mp4_enabled:
                    subscribe_to_save_mp4(EmptyRequest())
                if rollout_idx == 0:
                    for _ in range(MIN_EVAL_TRIALS):
                        graph_manager.evaluate(EnvironmentSteps(1))

                while new_checkpoint < last_checkpoint + 1:
                    exit_if_trainer_done(checkpoint_dir, s3_writer,
                                         rollout_idx)
                    if rollout_idx == 0:
                        graph_manager.evaluate(EnvironmentSteps(1))
                    new_checkpoint = data_store.get_chkpoint_num('agent')
                if is_save_mp4_enabled:
                    unsubscribe_from_save_mp4(EmptyRequest())
                s3_writer.upload_to_s3()

                pause_physics(EmptyRequest())
                data_store.load_from_store(
                    expected_checkpoint_number=last_checkpoint + 1)
                graph_manager.restore_checkpoint()

            if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                    == DistributedCoachSynchronizationType.ASYNC:
                if new_checkpoint > last_checkpoint:
                    graph_manager.restore_checkpoint()

            last_checkpoint = new_checkpoint
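In the SYNC branch of Examples #2 and #3, each rollout worker blocks until the trainer publishes checkpoint last_checkpoint + 1, filling the wait with evaluation episodes on worker 0 and polling on the other workers. Below is a stripped-down sketch of that polling barrier; wait_for_next_checkpoint and run_evaluation are illustrative names, and data_store is assumed to expose get_chkpoint_num as in the examples above.

import time


def wait_for_next_checkpoint(data_store, last_checkpoint, rollout_idx,
                             run_evaluation, poll_seconds=5):
    """Block until the data store reports a checkpoint newer than last_checkpoint.

    Worker 0 keeps the simulator busy with evaluation episodes while it waits
    (e.g. run_evaluation=lambda: graph_manager.evaluate(EnvironmentSteps(1)));
    every other worker simply sleeps between polls.
    """
    new_checkpoint = -1
    while new_checkpoint < last_checkpoint + 1:
        if rollout_idx == 0:
            run_evaluation()
        else:
            time.sleep(poll_seconds)
        new_checkpoint = data_store.get_chkpoint_num('agent')
    return new_checkpoint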
Example #4
0
def rollout_worker(graph_manager, num_workers, task_parameters, s3_writer):
    """
    Wait for the first checkpoint, then perform rollouts using the model.
    """
    if not graph_manager.data_store:
        raise AttributeError("None type for data_store object")

    data_store = graph_manager.data_store

    checkpoint_dir = task_parameters.checkpoint_restore_path
    wait_for_checkpoint(checkpoint_dir, data_store)
    wait_for_trainer_ready(checkpoint_dir, data_store)
    # Make the clients that will allow us to pause and unpause the physics
    rospy.wait_for_service('/gazebo/pause_physics')
    rospy.wait_for_service('/gazebo/unpause_physics')
    rospy.wait_for_service('/racecar/save_mp4/subscribe_to_save_mp4')
    rospy.wait_for_service('/racecar/save_mp4/unsubscribe_from_save_mp4')
    pause_physics = ServiceProxyWrapper('/gazebo/pause_physics', Empty)
    unpause_physics = ServiceProxyWrapper('/gazebo/unpause_physics', Empty)
    subscribe_to_save_mp4 = ServiceProxyWrapper('/racecar/save_mp4/subscribe_to_save_mp4', Empty)
    unsubscribe_from_save_mp4 = ServiceProxyWrapper('/racecar/save_mp4/unsubscribe_from_save_mp4', Empty)
    graph_manager.create_graph(task_parameters=task_parameters, stop_physics=pause_physics,
                               start_physics=unpause_physics, empty_service_call=EmptyRequest)

    with graph_manager.phase_context(RunPhase.TRAIN):
        chkpt_state_reader = CheckpointStateReader(checkpoint_dir, checkpoint_state_optional=False)
        last_checkpoint = chkpt_state_reader.get_latest().num

        for level in graph_manager.level_managers:
            for agent in level.agents.values():
                agent.memory.memory_backend.set_current_checkpoint(last_checkpoint)

        # this worker should play a fraction of the total playing steps per rollout
        act_steps = 1
        while True:
            graph_manager.phase = RunPhase.TRAIN
            exit_if_trainer_done(checkpoint_dir, s3_writer)
            unpause_physics(EmptyRequest())
            graph_manager.reset_internal_state(True)
            graph_manager.act(EnvironmentSteps(num_steps=act_steps), wait_for_full_episodes=graph_manager.agent_params.algorithm.act_for_full_episodes)
            graph_manager.reset_internal_state(True)
            time.sleep(1)
            pause_physics(EmptyRequest())

            graph_manager.phase = RunPhase.UNDEFINED
            new_checkpoint = data_store.get_chkpoint_num('agent')
            if new_checkpoint and new_checkpoint > last_checkpoint:
                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                        == DistributedCoachSynchronizationType.SYNC:
                    exit_if_trainer_done(checkpoint_dir, s3_writer)
                    unpause_physics(EmptyRequest())
                    is_save_mp4_enabled = rospy.get_param('MP4_S3_BUCKET', None)
                    if is_save_mp4_enabled:
                        subscribe_to_save_mp4(EmptyRequest())
                    for _ in range(MIN_EVAL_TRIALS):
                        graph_manager.evaluate(EnvironmentSteps(1))
                    if is_save_mp4_enabled:
                        unsubscribe_from_save_mp4(EmptyRequest())
                    s3_writer.upload_to_s3()

                    pause_physics(EmptyRequest())
                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type\
                        == DistributedCoachSynchronizationType.ASYNC:
                    graph_manager.restore_checkpoint()

                last_checkpoint = new_checkpoint
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.memory.memory_backend.set_current_checkpoint(last_checkpoint)
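Example #4 additionally pushes the restored checkpoint number down to every agent's memory backend so that newly collected experience is associated with the policy that produced it. The sketch below bundles that end-of-iteration bookkeeping into one helper; handle_new_checkpoint is an illustrative name, and it assumes the same DistributedCoachSynchronizationType and graph_manager structure imported and used in the example above.

def handle_new_checkpoint(graph_manager, data_store, last_checkpoint, sync_type):
    """Restore and propagate a newly published checkpoint, if any.

    Mirrors the tail of the loop above: ASYNC workers restore as soon as a newer
    checkpoint appears, and the checkpoint number is then broadcast to every
    agent's memory backend via set_current_checkpoint.
    """
    new_checkpoint = data_store.get_chkpoint_num('agent')
    if new_checkpoint is None or new_checkpoint <= last_checkpoint:
        return last_checkpoint

    if sync_type == DistributedCoachSynchronizationType.ASYNC:
        graph_manager.restore_checkpoint()

    for level in graph_manager.level_managers:
        for agent in level.agents.values():
            agent.memory.memory_backend.set_current_checkpoint(new_checkpoint)
    return new_checkpoint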