Ejemplo n.º 1
0
def subscription_profile() -> Profile:
    """Fetch the Azure CLI profile, logging in first when necessary.

    The profile is loaded with ``get_cli_profile``. If that raises
    ``CLIError``, ``run_az_cli_login`` is invoked and the profile is loaded
    a second time.

    Returns:
        azure.cli.core._profile.Profile: Azure profile
    """
    log = logging.getLogger(__name__)
    try:
        profile = get_cli_profile()
    except CLIError:
        log.info("Not logged in, running az login")
        run_az_cli_login()
        # Retry once now that a login has been attempted.
        profile = get_cli_profile()
    return profile
def use_token_authentication():
  """Export a storage access token through ``TF_AZURE_ACCESS_TOKEN``.

  When ``MSI_ENDPOINT`` is present in the environment the token is requested
  with ``get_msi_token`` (managed identity); otherwise it is read from the
  Azure CLI profile. Both requests target the ``https://storage.azure.com/``
  resource.

  If the Azure libraries cannot be imported, an error is logged and the
  function returns without modifying the environment.
  """
  import os
  from tensorflow.python.platform import tf_logging as log

  try:
    # client_factory and storage are not used below; they are imported so a
    # missing package is reported here, exactly as the original did.
    from azure.common import credentials, client_factory
    from azure.mgmt import storage
    from msrestazure.azure_active_directory import get_msi_token
  except ImportError:
    # ImportError (the parent of ModuleNotFoundError) is also what
    # ``from package import name`` raises when only the submodule is missing.
    log.error('Please install azure libraries with '
              '`python -m pip install -U azure-mgmt-storage azure-cli-core msrestazure` '
              'to use the cli authentication method')
    return

  storage_resource = 'https://storage.azure.com/'

  if 'MSI_ENDPOINT' in os.environ:
    log.info("Using credentials for managed identity")
    _token_type, access_token, _token_entry = get_msi_token(resource=storage_resource)
  else:
    log.info("Using credentials from az cli")
    creds, _subscription_id, _tenant_id = credentials.get_cli_profile().get_raw_token(
        resource=storage_resource)
    # Index 1 of the first tuple element is used as the access token.
    access_token = creds[1]

  os.environ['TF_AZURE_ACCESS_TOKEN'] = access_token
Ejemplo n.º 3
0
    def authenticate_cli(self) -> Credentials:
        """
        Implements authentication for the Azure provider using the Azure CLI session.

        Obtains CLI credentials for the default resource and for
        ``https://graph.windows.net``, then bundles them with the tenant id,
        the subscription id and the current CLI user into ``Credentials``.

        :return: the populated ``Credentials`` object
        :raises HttpResponseError: re-raised unchanged; an explanatory message
            is logged first when the error mentions the unsupported wstrust
            endpoint version
        """
        try:

            # Set logging level to error for libraries as otherwise generates a lot of warnings
            for noisy_logger in ('adal-python', 'msrest',
                                 'msrestazure.azure_active_directory', 'urllib3'):
                logging.getLogger(noisy_logger).setLevel(logging.ERROR)

            arm_credentials, subscription_id, tenant_id = get_azure_cli_credentials(with_tenant=True)
            # Only the credentials object is needed from the graph lookup.
            aad_graph_credentials, _, _ = get_azure_cli_credentials(
                with_tenant=True, resource='https://graph.windows.net',
            )

            profile = get_cli_profile()

            return Credentials(
                arm_credentials, aad_graph_credentials, tenant_id=tenant_id,
                current_user=profile.get_current_account_user(), subscription_id=subscription_id,
            )

        except HttpResponseError as e:
            wstrust_marker = (', AdalError: Unsupported wstrust endpoint version. '
                              'Current supported version is wstrust2005 or wstrust13.')
            # The marker starts with ", " so it is presumably a fragment of a
            # longer message; match it as a substring of any argument instead
            # of requiring an argument to equal it exactly.
            if any(wstrust_marker in str(arg) for arg in e.args):
                # Adjacent literals keep the source indentation out of the message.
                logger.error(
                    'You are likely authenticating with a Microsoft Account. '
                    'This authentication mode only supports Azure Active Directory '
                    'principal authentication. %s',
                    e,
                )

            # Bare raise re-raises the active exception with its traceback intact.
            raise
Ejemplo n.º 4
0
def _configure_resource_group(config):
    """Create or update the resource group and deploy the network template.

    Steps performed:
      1. Resolve the subscription id from ``config["provider"]`` or, when it
         is absent, from the Azure CLI profile, and store it back in the config.
      2. Create/update the resource group named by
         ``config["provider"]["resource_group"]`` at
         ``config["provider"]["location"]`` (with optional ``tags``).
      3. Load ``azure-config-template.json`` from this module's directory and
         deploy it incrementally as deployment ``ray-config``, waiting for
         the operation to finish.

    Args:
        config: cluster configuration dict with a ``"provider"`` section.

    Returns:
        The same ``config`` object, with ``provider.subscription_id`` filled in.

    Raises:
        AssertionError: if ``resource_group`` or ``location`` is missing from
            the provider section.
    """
    # TODO: look at availability sets
    # https://docs.microsoft.com/en-us/azure/virtual-machines/windows/tutorial-availability-sets
    subscription_id = config["provider"].get("subscription_id")
    if subscription_id is None:
        subscription_id = get_cli_profile().get_subscription_id()
    resource_client = ResourceManagementClient(AzureCliCredential(),
                                               subscription_id)
    config["provider"]["subscription_id"] = subscription_id
    logger.info("Using subscription id: %s", subscription_id)

    assert ("resource_group" in config["provider"]
            ), "Provider config must include resource_group field"
    resource_group = config["provider"]["resource_group"]

    assert (
        "location"
        in config["provider"]), "Provider config must include location field"
    params = {"location": config["provider"]["location"]}

    if "tags" in config["provider"]:
        params["tags"] = config["provider"]["tags"]

    logger.info("Creating/Updating Resource Group: %s", resource_group)
    rg_create_or_update = get_azure_sdk_function(
        client=resource_client.resource_groups,
        function_name="create_or_update")
    rg_create_or_update(resource_group_name=resource_group, parameters=params)

    # load the template file
    current_path = Path(__file__).parent
    template_path = current_path.joinpath("azure-config-template.json")
    with open(template_path, "r") as template_fp:
        template = json.load(template_fp)

    # choose a random subnet, skipping most common value of 0
    # A private generator seeded with the resource group name gives the same
    # deterministic subnet as seeding the module-level generator would, but
    # without resetting the global random state for the rest of the process.
    subnet_rng = random.Random(resource_group)
    subnet_mask = "10.{}.0.0/16".format(subnet_rng.randint(1, 254))

    parameters = {
        "properties": {
            "mode": DeploymentMode.incremental,
            "template": template,
            "parameters": {
                "subnet": {
                    "value": subnet_mask
                }
            },
        }
    }

    create_or_update = get_azure_sdk_function(
        client=resource_client.deployments, function_name="create_or_update")
    # Block until the deployment operation has completed.
    create_or_update(
        resource_group_name=resource_group,
        deployment_name="ray-config",
        parameters=parameters,
    ).wait()

    return config
 def _get_cli_profile(self, subscription_id):  # pylint:disable=no-self-use
     """Load the Azure CLI profile, a subscription id and the cloud endpoints.

     :param subscription_id: forwarded to ``profile.get_subscription`` to
         select the subscription.
     :returns: tuple ``(profile, subscription['id'], cloud.endpoints)``.
     :raises ValueError: when the CLI modules cannot be imported, when the
         CLI raises ``CLIError``, or when an ``AttributeError``, ``KeyError``
         or ``TypeError`` occurs while reading the session.
     """
     try:
         from azure.cli.core.util import CLIError
         from azure.cli.core.cloud import get_active_cloud
         try:
             cli_profile = get_cli_profile()
             active_cloud = get_active_cloud()
             selected = cli_profile.get_subscription(
                 subscription=subscription_id)
             return cli_profile, selected['id'], active_cloud.endpoints
         except CLIError:
             # The CLI is installed but could not provide a usable session.
             raise ValueError(
                 "Unable to load Azure CLI authenticated session. Please "
                 "run the 'az login' command or supply an AAD credentials "
                 "object from azure.common.credentials.")
     except ImportError:
         # Covers both the imports above and any import done by the calls.
         raise ValueError(
             'Unable to load Azure CLI authenticated session. Please '
             'supply an AAD credentials object from azure.common.credentials'
         )
     except (AttributeError, KeyError, TypeError) as error:
         conflict_message = (
             'Unable to load Azure CLI authenticated session. There is '
             'a version conflict with azure-cli-core. Please check for '
             'updates or report this issue at '
             'github.com/Azure/azure-batch-cli-extensions:\n{}'
         ).format(str(error))
         raise ValueError(conflict_message)
Ejemplo n.º 6
0
def query_microsoft_graph(
    method: str,
    resource: str,
    params: Optional[Dict] = None,
    body: Optional[Dict] = None,
) -> Any:
    """Send a request to Microsoft Graph v1.0 using the Azure CLI token.

    Args:
        method: HTTP method passed to ``requests.request``.
        resource: path joined onto ``https://graph.microsoft.com/v1.0/``
            with ``urllib.parse.urljoin``.
        params: optional query-string parameters.
        body: optional payload sent as the JSON request body.

    Returns:
        The decoded JSON body of a 2xx response, or ``None`` when that body
        cannot be decoded as JSON.

    Raises:
        GraphQueryError: when the response status is outside the 2xx range;
            carries the status code and the response text.
    """
    profile = get_cli_profile()
    (token_type, access_token, _), _, _ = profile.get_raw_token(
        resource="https://graph.microsoft.com"
    )
    url = urllib.parse.urljoin("https://graph.microsoft.com/v1.0/", resource)
    headers = {
        "Authorization": "%s %s" % (token_type, access_token),
        "Content-Type": "application/json",
    }
    response = requests.request(
        method=method, url=url, headers=headers, params=params, json=body
    )

    if 200 <= response.status_code < 300:
        try:
            return response.json()
        except ValueError:
            # A successful response whose body is not valid JSON.
            return None

    # backslashreplace keeps undecodable bytes visible instead of failing.
    error_text = str(response.content, encoding="utf-8", errors="backslashreplace")
    raise GraphQueryError(
        "request did not succeed: HTTP %s - %s"
        % (response.status_code, error_text),
        response.status_code,
    )
Ejemplo n.º 7
0
def get_tenant_id(subscription_id: Optional[str] = None) -> str:
    """Return the tenant id reported by the Azure CLI profile.

    Args:
        subscription_id: optional subscription forwarded to
            ``profile.get_raw_token``.

    Returns:
        The third element of the ``get_raw_token`` result, when it is a string.

    Raises:
        Exception: when that element is not a string.
    """
    cli_profile = get_cli_profile()
    raw_token = cli_profile.get_raw_token(resource=GRAPH_RESOURCE,
                                          subscription=subscription_id)
    _, _, tenant = raw_token
    if not isinstance(tenant, str):
        raise Exception(
            f"unable to retrive tenant_id for subscription {subscription_id}")
    return tenant
Ejemplo n.º 8
0
def query_microsoft_graph(
    method: str,
    resource: str,
    params: Optional[Dict] = None,
    body: Optional[Dict] = None,
    subscription: Optional[str] = None,
) -> Dict:
    """Send a request to Microsoft Graph and return the JSON object it yields.

    Args:
        method: HTTP method passed to ``requests.request``.
        resource: path joined onto ``GRAPH_RESOURCE_ENDPOINT``.
        params: optional query-string parameters.
        body: optional payload sent as the JSON request body.
        subscription: optional subscription forwarded to ``get_raw_token``.

    Returns:
        The decoded JSON object of a 2xx response, or ``{}`` when the
        response body is empty or whitespace only.

    Raises:
        GraphQueryError: when the status is outside the 2xx range, or when
            the decoded body is not a JSON object.
    """
    raw_token = get_cli_profile().get_raw_token(resource=GRAPH_RESOURCE,
                                                subscription=subscription)
    (token_type, access_token, _), _, _ = raw_token

    response = requests.request(
        method=method,
        url=urllib.parse.urljoin(f"{GRAPH_RESOURCE_ENDPOINT}/", resource),
        headers={
            "Authorization": "%s %s" % (token_type, access_token),
            "Content-Type": "application/json",
        },
        params=params,
        json=body,
    )

    # Failure path first: anything outside 2xx is reported with its body.
    if not 200 <= response.status_code < 300:
        error_text = str(response.content,
                         encoding="utf-8",
                         errors="backslashreplace")
        raise GraphQueryError(
            f"request did not succeed: HTTP {response.status_code} - {error_text}",
            response.status_code,
        )

    # A successful response without a meaningful body maps to an empty dict.
    if not (response.content and response.content.strip()):
        return {}

    json = response.json()
    if isinstance(json, Dict):
        return json
    raise GraphQueryError(
        f"invalid data received expected a json object: HTTP {response.status_code} - {json}",
        response.status_code,
    )
Ejemplo n.º 9
0
 def get_subscription_id(self) -> str:
     """Return the subscription id, looking it up once and caching it.

     A truthy ``self.subscription_id`` is returned as is; otherwise the id
     is read from the Azure CLI profile and stored on the instance first.
     """
     if not self.subscription_id:
         cli_profile = get_cli_profile()
         self.subscription_id = cast(str, cli_profile.get_subscription_id())
     return self.subscription_id
Ejemplo n.º 10
0
 def get_subscription_id(self):
     """Return the subscription id reported by the Azure CLI profile."""
     return get_cli_profile().get_subscription_id()
Ejemplo n.º 11
0
def select(sub_name_or_id):
    """Make ``sub_name_or_id`` the active subscription of the Azure CLI profile."""
    get_cli_profile().set_active_subscription(sub_name_or_id)
Ejemplo n.º 12
0
def get_azure_cli_credentials_non_default_sub(resource: str,
                                              subscription: str) -> Any:
    """Return Azure CLI login credentials for an explicitly chosen subscription.

    Args:
        resource: resource forwarded to ``get_login_credentials``.
        subscription: forwarded as ``subscription_id``.

    Returns:
        The first element of the ``get_login_credentials`` result.
    """
    login_result = get_cli_profile().get_login_credentials(
        resource=resource, subscription_id=subscription)
    # Only the credentials object is returned; the other two items are dropped.
    credential, _, _ = login_result
    return credential
Ejemplo n.º 13
0
 def get_subscription_id(self) -> str:
     """Return the subscription id reported by the Azure CLI profile."""
     subscription = get_cli_profile().get_subscription_id()
     return cast(str, subscription)
Ejemplo n.º 14
0
                                                    **kwargs)
        kwargs["storage_account_id"] = storage_account_id
    print("Created Storage account.")

    # create batch account
    with Spinner():
        batch_account_id = create_batch_account(creds, subscription_id,
                                                **kwargs)
    print("Created Batch account.")

    # create vnet with a subnet
    # subnet_id = create_vnet(creds, subscription_id)

    # create AAD application and service principal
    with Spinner():
        profile = credentials.get_cli_profile()
        aad_cred, subscription_id, tenant_id = profile.get_login_credentials(
            resource=AZURE_PUBLIC_CLOUD.endpoints.
            active_directory_graph_resource_id)
        application_id, service_principal_object_id, application_credential = create_aad_user(
            aad_cred, tenant_id, **kwargs)

    print("Created Azure Active Directory service principal.")

    with Spinner():
        create_role_assignment(creds, subscription_id, resource_group_id,
                               service_principal_object_id)
    print("Configured permissions.")

    secrets = format_secrets(
        **{
Ejemplo n.º 15
0
import azure.mgmt.storage as st
import azure.common.credentials as creds

# Load the Azure CLI profile and print it.
# NOTE(review): despite its name, az_cred holds whatever get_cli_profile()
# returns, presumably a profile object rather than a credential — confirm.
az_cred = creds.get_cli_profile()
print(az_cred)
Ejemplo n.º 16
0
    def __init__(
        self,
        location: str = None,
        resource_group: str = None,
        vnet: str = None,
        security_group: str = None,
        public_ingress: bool = None,
        vm_size: str = None,
        scheduler_vm_size: str = None,
        vm_image: dict = {},
        disk_size: int = None,
        bootstrap: bool = None,
        auto_shutdown: bool = None,
        docker_image=None,
        debug: bool = False,
        marketplace_plan: dict = {},
        **kwargs,
    ):
        """Resolve the Azure VM cluster settings and build scheduler/worker options.

        Each keyword argument overrides the matching value from the
        ``cloudprovider.azure`` dask configuration section when it is given.

        Parameters
        ----------
        location, resource_group, vnet, security_group
            Required settings; a ``ConfigError`` is raised when one of them
            resolves to ``None``.
        public_ingress
            Overrides ``azurevm.public_ingress``; passed to the scheduler only.
        vm_size, scheduler_vm_size
            VM sizes for workers and scheduler. The scheduler size falls back
            to ``vm_size`` when it resolves to ``None``.
        vm_image
            Keys given here are written over the configured ``azurevm.vm_image``.
        disk_size
            Overrides ``azurevm.disk_size``; values above 1023 are rejected.
        bootstrap, auto_shutdown, docker_image
            Override the ``azurevm.*`` options of the same name.
        debug
            Stored on the instance and forwarded to the parent constructor.
        marketplace_plan
            Used when truthy, else ``azurevm.marketplace_plan`` from the config.
            When set, ``name``, ``publisher`` and ``product`` must be non-empty.
        **kwargs
            Forwarded to the parent constructor.

        Raises
        ------
        ConfigError
            For a missing required setting or an incomplete marketplace plan.
        ValueError
            When ``disk_size`` is larger than 1023.
        """
        # NOTE: the ``{}`` defaults are only read, never mutated, in this method.
        self.config = ClusterConfig(dask.config.get("cloudprovider.azure", {}))
        self.scheduler_class = AzureVMScheduler
        self.worker_class = AzureVMWorker
        self.location = self.config.get("location", override_with=location)
        if self.location is None:
            raise ConfigError("You must configure a location")
        self.resource_group = self.config.get("resource_group",
                                              override_with=resource_group)
        if self.resource_group is None:
            raise ConfigError("You must configure a resource_group")
        self.public_ingress = self.config.get("azurevm.public_ingress",
                                              override_with=public_ingress)
        self.subscription_id = get_cli_profile().get_subscription_id()
        self.credentials = DefaultAzureCredential()
        self.compute_client = ComputeManagementClient(self.credentials,
                                                      self.subscription_id)
        self.network_client = NetworkManagementClient(self.credentials,
                                                      self.subscription_id)
        self.vnet = self.config.get("azurevm.vnet", override_with=vnet)
        if self.vnet is None:
            raise ConfigError("You must configure a vnet")
        self.security_group = self.config.get("azurevm.security_group",
                                              override_with=security_group)
        if self.security_group is None:
            raise ConfigError(
                "You must configure a security group which allows traffic on 8786 and 8787"
            )
        self.vm_size = self.config.get("azurevm.vm_size",
                                       override_with=vm_size)
        self.disk_size = self.config.get("azurevm.disk_size",
                                         override_with=disk_size)
        if self.disk_size > 1023:
            raise ValueError(
                "VM OS disk canot be larger than 1023. Please change the ``disk_size`` config option."
            )
        self.scheduler_vm_size = self.config.get(
            "azurevm.scheduler_vm_size", override_with=scheduler_vm_size)
        if self.scheduler_vm_size is None:
            self.scheduler_vm_size = self.vm_size
        # VM sizes containing "_NC" or "_ND" are flagged as GPU instances.
        self.gpu_instance = ("_NC" in self.vm_size.upper()
                             or "_ND" in self.vm_size.upper())
        self.vm_image = self.config.get("azurevm.vm_image")
        # NOTE(review): this writes into the object returned by the config;
        # confirm that object is not shared with the global dask config.
        for key in vm_image:
            self.vm_image[key] = vm_image[key]
        self.bootstrap = self.config.get("azurevm.bootstrap",
                                         override_with=bootstrap)
        self.auto_shutdown = self.config.get("azurevm.auto_shutdown",
                                             override_with=auto_shutdown)
        # Store the docker_image argument: self.docker_image is read below
        # when building self.options and was never assigned in this method.
        # NOTE(review): the "azurevm.docker_image" key follows the naming of
        # the sibling options — confirm it matches the config schema.
        self.docker_image = self.config.get("azurevm.docker_image",
                                            override_with=docker_image)
        self.debug = debug
        self.marketplace_plan = marketplace_plan or self.config.get(
            "azurevm.marketplace_plan")
        if self.marketplace_plan:
            # Check that self.marketplace_plan contains the right options with values
            if not all(
                    self.marketplace_plan.get(item, "") != ""
                    for item in ["name", "publisher", "product"]):
                raise ConfigError(
                    """To create a virtual machine from Marketplace image or a custom image sourced
                from a Marketplace image with a plan, all 3 fields 'name', 'publisher' and 'product' must be passed."""
                )

        # Options shared by the scheduler and every worker.
        self.options = {
            "cluster": self,
            "config": self.config,
            "security_group": self.security_group,
            "location": self.location,
            "vm_image": self.vm_image,
            "disk_size": self.disk_size,
            "gpu_instance": self.gpu_instance,
            "bootstrap": self.bootstrap,
            "auto_shutdown": self.auto_shutdown,
            "docker_image": self.docker_image,
            "marketplace_plan": self.marketplace_plan,
        }
        self.scheduler_options = {
            "vm_size": self.scheduler_vm_size,
            "public_ingress": self.public_ingress,
            **self.options,
        }
        self.worker_options = {"vm_size": self.vm_size, **self.options}
        super().__init__(debug=debug, **kwargs)