def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()
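        # door_man.terminate_now flips to True when SIGTERM arrives, letting the
        # loop below checkpoint and exit cleanly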

        while steps < graph_manager.improve_steps.num_steps:
            graph_manager.phase = core_types.RunPhase.TRAIN
            graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
            graph_manager.phase = core_types.RunPhase.UNDEFINED

            episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

            if graph_manager.should_train():
                # Make sure we have enough data for the requested batches
                rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                if any(count <= 0 for count in rollout_steps.values()):
                    utils.json_format_logger("No rollout data retrieved from the rollout worker",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                # Set the batch size to the closest power of 2 such that we have at least two
                # batches; this prevents Coach from crashing, since a batch size of less than 2
                # causes the batch list to become a scalar, which raises an exception.
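                # e.g. with user_batch_size = 64 but only 40 rollout steps, the batch size
                # falls back to 2 ** floor(log2(40)) = 32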
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                steps += 1

                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.train()
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                # Check for NaNs in all agents
                rollout_has_nan = False
                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        if np.isnan(agent.loss.get_mean()):
                            rollout_has_nan = True
                if rollout_has_nan:
                    utils.json_format_logger("NaN detected in loss function, aborting training.",
                                             **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                             utils.SIMAPP_EVENT_ERROR_CODE_500))
                    utils.simapp_exit_gracefully()

                if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                    graph_manager.save_checkpoint()
                else:
                    graph_manager.occasionally_save_checkpoint()

                # Clear any data stored in signals that is no longer necessary
                graph_manager.reset_internal_state()

            # Restore the user-configured episodes per rollout before fetching the next rollout
            for level in graph_manager.level_managers:
                for agent in level.agents.values():
                    agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                    agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

            if door_man.terminate_now:
                utils.json_format_logger("Received SIGTERM. Checkpointing before exiting.",
                                         **utils.build_system_error_dict(utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                                                                         utils.SIMAPP_EVENT_ERROR_CODE_500))
                graph_manager.save_checkpoint()
                break

    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            utils.log_and_exit("User modified model: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_400)
        else:
            utils.log_and_exit("An error occured while training: {}".format(err),
                               utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                               utils.SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        utils.log_and_exit("An error occured while training: {}".format(ex),
                           utils.SIMAPP_TRAINING_WORKER_EXCEPTION,
                           utils.SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
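
# A minimal sketch of the utils.DoorMan helper used above: it registers signal
# handlers that flip terminate_now so the training loop can checkpoint and exit
# cleanly. The class name and handler details here are illustrative assumptions,
# not the exact markov.utils implementation.
import signal


class DoorManSketch:
    def __init__(self):
        self.terminate_now = False
        signal.signal(signal.SIGTERM, self._exit_gracefully)
        signal.signal(signal.SIGINT, self._exit_gracefully)

    def _exit_gracefully(self, signum, frame):
        # Only record the request; the training loop decides when to stop
        self.terminate_now = True
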
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout, training_algorithm):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET,
                                s3_prefix=PROFILER_S3_PREFIX,
                                output_local_path=TRAINING_WORKER_PROFILER_PATH,
                                enable_profiling=IS_PROFILER_ON):
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(
                    graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

                # TODO: Refactor the flow to remove conditional checks for specific algorithms
                # ------------------------------ SAC only -------------------------------------
                if training_algorithm == TrainingAlgorithm.SAC.value:
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()

                    # NOTE: SAC can train for even more iterations than rollout_steps by
                    # increasing the number below
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            agent.ap.algorithm.num_consecutive_training_steps = \
                                list(rollout_steps.values())[0]
                # ------------------------------------------------------------------------------
                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    if any(count <= 0 for count in rollout_steps.values()):
                        log_and_exit(
                            "No rollout data retrieved from the rollout worker",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    # TODO: Refactor the flow to remove conditional checks for specific algorithms
                    # For SAC, check whether the experience replay memory has enough transitions
                    logger.info("Configuring batch size for training algorithm: {}".format(training_algorithm))
                    if training_algorithm == TrainingAlgorithm.SAC.value:
                        replay_mem_size = min([
                            agent.memory.num_transitions()
                            for level in graph_manager.level_managers
                            for agent in level.agents.values()
                        ])
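                        # When the replay memory holds fewer transitions than the requested
                        # batch size, fall back to the power-of-2 rule based on rollout steps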
                        episode_batch_size = user_batch_size if replay_mem_size > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    else:
                        logger.info("Using clipped PPO batch size configuration")
                        episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size \
                            else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    # Set the batch size to the closest power of 2 such that we have at least two
                    # batches; this prevents Coach from crashing, since a batch size of less than 2
                    # causes the batch list to become a scalar, which raises an exception.
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            for net_key in agent.ap.network_wrappers:
                                agent.ap.network_wrappers[net_key].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check for NaNs in all agents
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                    if rollout_has_nan:
                        log_and_exit(
                            "NaN detected in loss function, aborting training.",
                            SIMAPP_TRAINING_WORKER_EXCEPTION,
                            SIMAPP_EVENT_ERROR_CODE_500)

                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout

                if door_man.terminate_now:
                    # Save the checkpoint first so the exit below does not lose the latest weights
                    graph_manager.save_checkpoint()
                    log_and_exit(
                        "Received SIGTERM. Checkpointing before exiting.",
                        SIMAPP_TRAINING_WORKER_EXCEPTION,
                        SIMAPP_EVENT_ERROR_CODE_500)
                    break

    except ValueError as err:
        if utils.is_user_error(err):
            log_and_exit("User modified model: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("An error occured while training: {}".format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}".format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()
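
# The batch-size fallback used in the training loops above can be isolated into a
# small helper for readability; this is an illustrative sketch, not part of the
# original module.
import math


def clamp_batch_size(user_batch_size, available_steps):
    """Return user_batch_size when enough data is available, otherwise the largest
    power of two not exceeding available_steps, mirroring the fallback above."""
    if available_steps > user_batch_size:
        return user_batch_size
    return 2 ** math.floor(math.log(available_steps, 2))


# e.g. clamp_batch_size(64, 40) == 32 and clamp_batch_size(64, 500) == 64
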
def training_worker(graph_manager, task_parameters, user_batch_size,
                    user_episode_per_rollout):
    try:
        # initialize graph
        graph_manager.create_graph(task_parameters)

        # save initial checkpoint
        graph_manager.save_checkpoint()

        # training loop
        steps = 0

        graph_manager.setup_memory_backend()
        graph_manager.signal_ready()

        # To handle SIGTERM
        door_man = utils.DoorMan()
        
        # Wrap each network's optimizer with the smdebug hook so it can capture tensors
        # while the graph trains
        for level in graph_manager.level_managers:
            for agent in level.agents.values():
                for name, network in agent.networks.items():
                    print("NETWORK: {}: {}".format(name, network))

                    if network.global_network is not None:
                        network.global_network.optimizer = graph_manager.smdebug_hook.wrap_optimizer(network.global_network.optimizer)

                    if network.online_network is not None:
                        network.online_network.optimizer = graph_manager.smdebug_hook.wrap_optimizer(network.online_network.optimizer)

                    if network.target_network is not None:
                        network.target_network.optimizer = graph_manager.smdebug_hook.wrap_optimizer(network.target_network.optimizer)

        while steps < graph_manager.improve_steps.num_steps:
            # Collect profiler information only if IS_PROFILER_ON is true
            with utils.Profiler(s3_bucket=PROFILER_S3_BUCKET, s3_prefix=PROFILER_S3_PREFIX,
                                output_local_path=TRAINING_WORKER_PROFILER_PATH, enable_profiling=IS_PROFILER_ON):
                graph_manager.phase = core_types.RunPhase.TRAIN
                graph_manager.fetch_from_worker(graph_manager.agent_params.algorithm.num_consecutive_playing_steps)
                graph_manager.phase = core_types.RunPhase.UNDEFINED

                episodes_in_rollout = graph_manager.memory_backend.get_total_episodes_in_rollout()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = episodes_in_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = episodes_in_rollout

                if graph_manager.should_train():
                    # Make sure we have enough data for the requested batches
                    rollout_steps = graph_manager.memory_backend.get_rollout_steps()
                    if any(count <= 0 for count in rollout_steps.values()):
                        log_and_exit("No rollout data retrieved from the rollout worker",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

                    episode_batch_size = user_batch_size if min(rollout_steps.values()) > user_batch_size else 2**math.floor(math.log(min(rollout_steps.values()), 2))
                    # Set the batch size to the closest power of 2 such that we have at least two
                    # batches; this prevents Coach from crashing, since a batch size of less than 2
                    # causes the batch list to become a scalar, which raises an exception.
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            agent.ap.network_wrappers['main'].batch_size = episode_batch_size

                    steps += 1

                    graph_manager.phase = core_types.RunPhase.TRAIN
                    graph_manager.train()
                    graph_manager.phase = core_types.RunPhase.UNDEFINED

                    # Check for NaNs in all agents
                    rollout_has_nan = False
                    for level in graph_manager.level_managers:
                        for agent in level.agents.values():
                            if np.isnan(agent.loss.get_mean()):
                                rollout_has_nan = True
                            
                    if rollout_has_nan:
                        log_and_exit("NaN detected in loss function, aborting training.",
                                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                                     SIMAPP_EVENT_ERROR_CODE_500)

                    if graph_manager.agent_params.algorithm.distributed_coach_synchronization_type == DistributedCoachSynchronizationType.SYNC:
                        graph_manager.save_checkpoint()
                    else:
                        graph_manager.occasionally_save_checkpoint()

                    # Clear any data stored in signals that is no longer necessary
                    graph_manager.reset_internal_state()

                for level in graph_manager.level_managers:
                    for agent in level.agents.values():
                        agent.ap.algorithm.num_consecutive_playing_steps.num_steps = user_episode_per_rollout
                        agent.ap.algorithm.num_steps_between_copying_online_weights_to_target.num_steps = user_episode_per_rollout
                        

                if door_man.terminate_now:
                    # Save the checkpoint first so the exit below does not lose the latest weights
                    graph_manager.save_checkpoint()
                    log_and_exit("Received SIGTERM. Checkpointing before exiting.",
                                 SIMAPP_TRAINING_WORKER_EXCEPTION,
                                 SIMAPP_EVENT_ERROR_CODE_500)
                    break

    except ValueError as err:
        if utils.is_error_bad_ckpnt(err):
            log_and_exit("User modified model: {}"
                             .format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            log_and_exit("An error occured while training: {}"
                             .format(err),
                         SIMAPP_TRAINING_WORKER_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
    except Exception as ex:
        log_and_exit("An error occured while training: {}"
                         .format(ex),
                     SIMAPP_TRAINING_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    finally:
        graph_manager.data_store.upload_finished_file()