def setUp(self):
    """Register the ACI test connection and build a hook with Azure clients patched out.

    Stores ``self.resources`` (4 GB memory / 1 CPU request) and ``self.hook``
    for use by the individual tests.
    """
    extra_payload = json.dumps({'tenantId': 'tenant_id', 'subscriptionId': 'subscription_id'})
    db.merge_conn(
        Connection(
            conn_id='azure_container_instance_test',
            conn_type='azure_container_instances',
            login='******',
            password='******',
            extra=extra_payload,
        )
    )

    self.resources = ResourceRequirements(
        requests=ResourceRequests(memory_in_gb='4', cpu='1')
    )

    # Patch credential construction and the management client so building the
    # hook does not need real Azure credentials.
    credentials_init_target = 'azure.common.credentials.ServicePrincipalCredentials.__init__'
    client_target = 'azure.mgmt.containerinstance.ContainerInstanceManagementClient'
    with patch(credentials_init_target, autospec=True, return_value=None), patch(client_target):
        self.hook = AzureContainerInstanceHook(conn_id='azure_container_instance_test')
def execute(self, context: dict) -> int:
    """Start the Azure container group, follow its logs and return the exit code.

    :param context: task context; not read by this method.
    :return: the exit code reported by ``_monitor_logging``. Any non-zero code
        raises instead, so a normal return always yields ``0``.
    :raises AirflowException: if ``fail_if_exists`` is set and the group already
        exists, if the container exits with a non-zero code, or if a
        ``CloudError`` occurs while starting or monitoring the group (the
        original ``CloudError`` is chained as ``__cause__``).
    """
    # Check name again in case it was templated.
    self._check_name(self.name)

    self._ci_hook = AzureContainerInstanceHook(self.ci_conn_id)

    if self.fail_if_exists:
        self.log.info("Testing if container group already exists")
        if self._ci_hook.exists(self.resource_group, self.name):
            raise AirflowException("Container group exists")

    if self.registry_conn_id:
        registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
        image_registry_credentials: Optional[list] = [
            registry_hook.connection,
        ]
    else:
        image_registry_credentials = None

    # Keys listed in ``secured_variables`` are sent as secure values; all other
    # keys are sent as plain values.
    environment_variables = [
        EnvironmentVariable(name=key, secure_value=value)
        if key in self.secured_variables
        else EnvironmentVariable(name=key, value=value)
        for key, value in self.environment_variables.items()
    ]

    volumes: List[Union[Volume, Volume]] = []
    volume_mounts: List[Union[VolumeMount, VolumeMount]] = []
    for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
        hook = AzureContainerVolumeHook(conn_id)
        # Mount names are positional ("mount-0", "mount-1", ...) so each volume
        # matches its VolumeMount entry by name.
        mount_name = f"mount-{len(volumes)}"
        volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only))
        volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only))

    # Non-zero sentinel: unless the container reports success, the ``finally``
    # clause below only cleans up when ``remove_on_error`` is set.
    exit_code = 1
    try:
        self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb)
        if self.gpu:
            self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku)

        resources = ResourceRequirements(
            requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu)
        )

        # A public IP with no explicit ports falls back to port 80.
        if self.ip_address and not self.ports:
            self.ports = [ContainerPort(port=80)]
            self.log.info("Default port set. Container will listen on port 80")

        container = Container(
            name=self.name,
            image=self.image,
            resources=resources,
            command=self.command,
            environment_variables=environment_variables,
            volume_mounts=volume_mounts,
            ports=self.ports,
        )

        container_group = ContainerGroup(
            location=self.region,
            containers=[
                container,
            ],
            image_registry_credentials=image_registry_credentials,
            volumes=volumes,
            restart_policy=self.restart_policy,
            os_type=self.os_type,
            tags=self.tags,
            ip_address=self.ip_address,
        )

        self._ci_hook.create_or_update(self.resource_group, self.name, container_group)

        self.log.info("Container group started %s/%s", self.resource_group, self.name)

        exit_code = self._monitor_logging(self.resource_group, self.name)

        self.log.info("Container had exit code: %s", exit_code)
        if exit_code != 0:
            raise AirflowException(f"Container had a non-zero exit code, {exit_code}")
        return exit_code

    except CloudError as err:
        self.log.exception("Could not start container group")
        # Chain the Azure error so its details stay attached to the failure.
        raise AirflowException("Could not start container group") from err

    finally:
        # ``on_kill`` is defined elsewhere in the class; presumably it deletes
        # the container group (TODO confirm). It runs on success, or on failure
        # when ``remove_on_error`` is set.
        if exit_code == 0 or self.remove_on_error:
            self.on_kill()