def setup_ecs_inference_service(
    docker_image_uri,
    framework,
    cluster_arn,
    model_name,
    worker_instance_id,
    ei_accelerator_type=None,
    num_gpus=None,
    region=DEFAULT_REGION,
):
    """
    Function to setup Inference service on ECS.

    Registers a task definition for the given image, creates an ECS service from it and verifies
    that a task in that service is running.

    :param docker_image_uri: image URI; the processor ("gpu", "eia" or "cpu") is inferred from
        substrings of this URI
    :param framework: framework name; "tensorflow" adds container environment variables,
        "mxnet"/"pytorch" add an MMS run command
    :param cluster_arn: ARN of the ECS cluster to create the service in
    :param model_name: name of the model to serve
    :param worker_instance_id: EC2 instance ID whose CPU count and memory size the task
    :param ei_accelerator_type: Elastic Inference accelerator type; only used for "eia" images
    :param num_gpus: number of GPUs to reserve; only used for "gpu" images
    :param region: AWS region
    :return: <tuple> service_name, task_family, revision if all steps passed
    :raises ECSServiceCreationException: if any step fails. The original error is chained as
        ``__cause__``. This function does NOT clean up resources created before the failure.
    """
    datetime_suffix = datetime.datetime.now().strftime("%Y%m%d-%H-%M-%S")

    # "gpu" takes precedence over "eia"; anything else is treated as a CPU image.
    if "gpu" in docker_image_uri:
        processor = "gpu"
    elif "eia" in docker_image_uri:
        processor = "eia"
    else:
        processor = "cpu"

    port_mappings = get_ecs_port_mappings(framework)
    log_group_name = f"/ecs/{framework}-inference-{processor}"
    num_cpus = ec2_utils.get_instance_num_cpus(worker_instance_id, region=region)
    # We assume that about 80% of RAM is free on the instance, since we are not directly querying it to find out
    # what the memory utilization is.
    memory = int(ec2_utils.get_instance_memory(worker_instance_id, region=region) * 0.8)
    cluster_name = get_ecs_cluster_name(cluster_arn, region=region)

    # Below values here are just for sanity
    arguments_dict = {
        "family_name": cluster_name,
        "image": docker_image_uri,
        "log_group_name": log_group_name,
        "log_stream_prefix": datetime_suffix,
        "port_mappings": port_mappings,
        "num_cpu": num_cpus,
        "memory": memory,
        "region": region,
    }

    if processor == "gpu" and num_gpus:
        arguments_dict["num_gpu"] = num_gpus

    if framework == "tensorflow":
        arguments_dict["environment"] = get_ecs_tensorflow_environment_variables(processor, model_name)
        print(f"Added environment variables: {arguments_dict['environment']}")
    elif framework in ["mxnet", "pytorch"]:
        arguments_dict["container_command"] = [get_mms_run_command(model_name, processor)]

    if processor == "eia":
        # Container health check runs the EI health-check binary shipped in the image.
        arguments_dict["health_check"] = {
            "retries": 2,
            "command": [
                "CMD-SHELL",
                "LD_LIBRARY_PATH=/opt/ei_health_check/lib /opt/ei_health_check/bin/health_check",
            ],
            "timeout": 5,
            "interval": 30,
            "startPeriod": 60,
        }
        arguments_dict["inference_accelerators"] = {
            "deviceName": "device_1",
            "deviceType": ei_accelerator_type,
        }

    try:
        task_family, revision = register_ecs_task_definition(**arguments_dict)
        print(f"Created Task definition - {task_family}:{revision}")

        service_name = create_ecs_service(
            cluster_name, f"service-{cluster_name}", f"{task_family}:{revision}", region=region
        )
        print(
            f"Created ECS service - {service_name} with cloudwatch log group - {log_group_name} "
            f"log stream prefix - {datetime_suffix}/{cluster_name}")
        if not check_running_task_for_ecs_service(cluster_name, service_name, region=region):
            raise Exception(f"No task running in the service: {service_name}")
        print("Service status verified as running. Running inference ...")
        return service_name, task_family, revision
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise ECSServiceCreationException(f"Setup Inference Service Exception - {e}") from e
def ecs_training_test_executor(cluster_name,
                               cluster_arn,
                               training_command,
                               image_uri,
                               instance_id,
                               num_gpus=None):
    """
    Function to run training task on ECS; Cleans up the resources after each execution.

    Registers a task definition, starts a task and waits for it to stop. The teardown helper
    always runs in ``finally``, whether the run succeeded or failed.

    :param cluster_name: name of the ECS cluster; also used as the task definition family name
    :param cluster_arn: ARN of the ECS cluster, passed to the teardown helper
    :param training_command: command run in the container through the ``sh -c`` entrypoint
    :param image_uri: image URI; its tag names the CloudWatch log group and decides whether
        GPUs are requested
    :param instance_id: EC2 instance ID whose CPU count and memory size the task
    :param num_gpus: number of GPUs to reserve when the image tag contains "gpu"
    :return: None when the task stopped and no failing exit codes were reported
    :raises ECSTrainingTestFailure: if the stopped task reports failing exit codes
    :raises ECSTaskNotStoppedError: if the waiter reports that the task did not stop
    """
    # Set defaults to satisfy finally case
    task_arn = None
    task_family = None
    revision = None

    # Define constants for arguments to be sent to task def
    image_tag = image_uri.split(':')[-1]
    # CloudWatch log group names always use "/" separators, so don't build them with os.path
    # (which would use "\" on Windows).
    log_group_name = f"/ecs/{image_tag}"
    datetime_suffix = datetime.datetime.now().strftime('%Y%m%d-%H-%M-%S')
    num_cpus = ec2_utils.get_instance_num_cpus(instance_id)
    # Assume ~80% of instance RAM is available to the task (not queried directly).
    memory = int(ec2_utils.get_instance_memory(instance_id) * 0.8)

    arguments_dict = {
        "family_name": cluster_name,
        "image": image_uri,
        "log_group_name": log_group_name,
        "log_stream_prefix": datetime_suffix,
        "num_cpu": num_cpus,
        "memory": memory,
        "entrypoint": ["sh", "-c"],
        "container_command": training_command
    }

    if "gpu" in image_tag and num_gpus:
        arguments_dict["num_gpu"] = str(num_gpus)
    try:
        task_family, revision = register_ecs_task_definition(**arguments_dict)
        print(f"Created Task definition - {task_family}:{revision}")

        task_arn = create_ecs_task(cluster_name, f"{task_family}:{revision}")
        print(
            f"Created ECS task - {task_arn} with cloudwatch log group - {log_group_name} log stream prefix - "
            f"{datetime_suffix}/{cluster_name}")
        print("Waiting for task to stop ...")

        if not ecs_task_waiter(cluster_name, [task_arn], "tasks_stopped"):
            # Raise error if the task does not stop
            raise ECSTaskNotStoppedError(f"Task not stopped {task_arn}")

        ret_codes = describe_ecs_task_exit_status(cluster_name, task_arn)
        if ret_codes:
            # Assemble error message if we have nonzero return codes
            error_msg = "Failures:\n" + "".join(
                "------------------\n" + "".join(f"{key}: {value}\n" for key, value in code.items())
                for code in ret_codes
            )
            raise ECSTrainingTestFailure(error_msg)
        # Return gracefully if task stops
    finally:
        tear_down_ecs_training_task(cluster_arn, task_arn, task_family,
                                    revision)