Example 1
def training_worker(graph_manager, checkpoint_dir, use_pretrained_model,
                    framework):
    """
    Restore a pretrained checkpoint if requested, then train on experience fetched from the rollout workers.
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0
    graph_manager.setup_memory_backend()

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if graph_manager.should_train():
                steps += graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

            if door_man.terminate_now:
                "Received SIGTERM. Checkpointing before exiting."
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        raise RuntimeError("An error occurred while training: %s" % e)
    finally:
        print("Terminating training worker")
        graph_manager.data_store.upload_finished_file()
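
Every variant polls door_man.terminate_now inside the training loop, but the DoorMan class itself is not part of these snippets. Below is a minimal sketch of what the examples appear to assume, namely a flag that flips when SIGTERM arrives; the class name comes from the snippets, while the body is an assumption rather than the original DeepRacer implementation.

import signal

class DoorMan:
    """Sketch of the SIGTERM helper polled by the training loop."""
    def __init__(self):
        self.terminate_now = False
        # Flip the flag when the process receives SIGTERM so the loop can
        # checkpoint and exit cleanly on its next iteration.
        signal.signal(signal.SIGTERM, self._handle_sigterm)

    def _handle_sigterm(self, signum, frame):
        self.terminate_now = True
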
Example 2

def training_worker(graph_manager, checkpoint_dir, use_pretrained_model,
                    framework, memory_backend_params, user_batch_size,
                    user_episode_per_rollout):
    """
    Restore a pretrained checkpoint if requested, then train on experience fetched from the rollout workers.
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0

    graph_manager.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(
        memory_backend_params)

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                if rollout_steps <= 0:
                    utils.json_format_logger(
                        "No rollout data retrieved from the rollout worker",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                # Set the batch size to the closest power of 2 so that we have at least
                # two batches; a batch size below 2 collapses the batch list to a scalar
                # and makes coach raise an exception.
                episode_batch_size = user_batch_size if rollout_steps > user_batch_size \
                    else 2**math.floor(math.log(rollout_steps, 2))
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = RunPhase.UNDEFINED

                # Check for NaNs in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                # TODO: handle NaNs on a per-agent level for distributed training
                if rollout_has_nan:
                    utils.json_format_logger(
                        "NaN detected in loss function, aborting training.",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()
                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger(
                    "Received SIGTERM. Checkpointing before exiting.",
                    **utils.build_system_error_dict(
                        utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        utils.json_format_logger(
            "An error occured while training: {}.".format(e),
            **utils.build_system_error_dict(
                utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_500))
        traceback.print_exc()
        utils.simapp_exit_gracefully()
    finally:
        graph_manager.data_store.upload_finished_file()
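
Example 2 shrinks the training batch size whenever a rollout produced fewer steps than the user requested. The same rule is pulled out below as a standalone helper, purely for illustration; the function name is ours, the original computes this inline.

import math

def clamp_batch_size(user_batch_size, rollout_steps):
    # Keep the requested batch size when the rollout holds more steps than
    # that; otherwise fall back to the largest power of two not exceeding the
    # rollout length (the caller has already checked rollout_steps > 0).
    if rollout_steps > user_batch_size:
        return user_batch_size
    return 2 ** math.floor(math.log(rollout_steps, 2))

# For instance: clamp_batch_size(64, 200) -> 64, clamp_batch_size(64, 40) -> 32.
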
Example 3
def training_worker(graph_manager, checkpoint_dir, use_pretrained_model,
                    framework, memory_backend_params):
    """
    Restore a pretrained checkpoint if requested, then train on experience fetched from the rollout workers.
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = SM_MODEL_OUTPUT_DIR

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0

    graph_manager.memory_backend = deepracer_memory.DeepRacerTrainerBackEnd(
        memory_backend_params)

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(
                graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if graph_manager.should_train():
                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for NaNs in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                # TODO: handle NaNs on a per-agent level for distributed training
                if rollout_has_nan:
                    utils.json_format_logger(
                        "NaN detected in loss function, aborting training. Job failed!",
                        **utils.build_system_error_dict(
                            utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                            utils.SIMAPP_EVENT_ERROR_CODE_503))
                    sys.exit(1)

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()
                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            if door_man.terminate_now:
                utils.json_format_logger(
                    "Received SIGTERM. Checkpointing before exiting.",
                    **utils.build_system_error_dict(
                        utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                        utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        utils.json_format_logger(
            "An error occured while training: {}. Job failed!.".format(e),
            **utils.build_system_error_dict(
                utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                utils.SIMAPP_EVENT_ERROR_CODE_503))
        traceback.print_exc()
        sys.exit(1)
    finally:
        graph_manager.data_store.upload_finished_file()
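
Examples 2 and 3 scan every agent for a NaN mean loss after each training step. That check reads naturally as a small predicate; the helper name below is ours and is shown only for illustration.

import numpy as np

def any_agent_loss_is_nan(graph_manager):
    # Walk every agent in every level manager and report whether any of them
    # has a NaN mean loss, mirroring the inline check in the examples.
    for level in graph_manager.level_managers:
        for agent in level.agents.values():
            if np.isnan(agent.loss.get_mean()):
                return True
    return False
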
Example 4

def training_worker(graph_manager, checkpoint_dir, use_pretrained_model, framework):
    """
    Restore a pretrained checkpoint if requested, then train on experience fetched from the rollout workers.
    """
    # initialize graph
    task_parameters = TaskParameters()
    task_parameters.__dict__['checkpoint_save_dir'] = checkpoint_dir
    task_parameters.__dict__['checkpoint_save_secs'] = 20
    task_parameters.__dict__['experiment_path'] = INTERMEDIATE_FOLDER

    if framework.lower() == "mxnet":
        task_parameters.framework_type = Frameworks.mxnet
        if hasattr(graph_manager, 'agent_params'):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = Frameworks.mxnet
        elif hasattr(graph_manager, 'agents_params'):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = Frameworks.mxnet

    if use_pretrained_model:
        task_parameters.__dict__['checkpoint_restore_dir'] = PRETRAINED_MODEL_DIR

    graph_manager.create_graph(task_parameters)

    # save randomly initialized graph
    graph_manager.save_checkpoint()

    # training loop
    steps = 0
    graph_manager.setup_memory_backend()

    # To handle SIGTERM
    door_man = DoorMan()

    try:
        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            if graph_manager.should_train():
                steps += graph_manager.agent_params.algorithm.num_consecutive_playing_steps.num_steps

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

            if door_man.terminate_now:
                "Received SIGTERM. Checkpointing before exiting."
                graph_manager.save_checkpoint()
                break

    except Exception as e:
        raise RuntimeError("An error occurred while training: %s" % e)
    finally:
        print("Terminating training worker")
        graph_manager.data_store.upload_finished_file()
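
All four variants populate task_parameters by writing into its __dict__ directly. For an ordinary Python object that is equivalent to plain attribute assignment, as the toy comparison below illustrates; it uses a stand-in class rather than the real rl_coach TaskParameters, which may also accept these fields as constructor arguments.

class FakeTaskParameters:
    """Stand-in for TaskParameters, used only to illustrate the point."""
    pass

a = FakeTaskParameters()
a.__dict__['checkpoint_save_dir'] = '/tmp/checkpoints'

b = FakeTaskParameters()
b.checkpoint_save_dir = '/tmp/checkpoints'

# Both objects end up in the same state.
assert a.__dict__ == b.__dict__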