def list_sct_runners(cls) -> list[SctRunnerInfo]:
    """Collect an SctRunnerInfo entry for every Azure instance tagged as an SCT runner node."""
    service = AzureService()
    runners = []
    for vm in list_instances_azure(tags_dict={"NodeType": cls.NODE_TYPE}, verbose=True):
        launch_time = vm.tags.get("launch_time") or None
        if launch_time:
            try:
                launch_time = datetime_from_formatted(date_string=launch_time)
            except ValueError as exc:
                LOGGER.warning("Value of `launch_time' tag is invalid: %s", exc)
                launch_time = None
        public_ip = service.get_virtual_machine_ips(virtual_machine=vm).public_ip
        runners.append(SctRunnerInfo(
            sct_runner_class=cls,
            cloud_service_instance=service,
            region_az=vm.location,
            instance=vm,
            instance_name=vm.name,
            public_ips=[public_ip],
            launch_time=launch_time,
            keep=vm.tags.get("keep"),
            keep_action=vm.tags.get("keep_action"),
        ))
    return runners
def __init__(self, test_id: str, region: str,  # pylint: disable=unused-argument
             azure_service: AzureService = AzureService(),
             **kwargs):
    """Build a provisioner for the given test_id/region and discover existing resources.

    Instantiates one provider per Azure resource kind (resource group, security
    group, vnet, subnet, public ip, nic, vm); each provider discovers what already
    exists in the resource group, then the VM cache is pre-populated from the
    VM provider.
    """
    super().__init__(test_id, region)
    self._azure_service: AzureService = azure_service
    # VM name -> wrapped VmInstance, filled from the VM provider below.
    self._cache: Dict[str, VmInstance] = {}
    # NOTE(review): self._resource_group_name / self._region presumably come from the
    # base class __init__(test_id, region) -- not visible in this chunk; confirm.
    LOGGER.debug("getting resources for %s...", self._resource_group_name)
    # All providers share the same resource group, region and service client.
    self._rg_provider = ResourceGroupProvider(self._resource_group_name, self._region, self._azure_service)
    self._network_sec_group_provider = NetworkSecurityGroupProvider(self._resource_group_name, self._region,
                                                                    self._azure_service)
    self._vnet_provider = VirtualNetworkProvider(self._resource_group_name, self._region, self._azure_service)
    self._subnet_provider = SubnetProvider(self._resource_group_name, self._azure_service)
    self._ip_provider = IpAddressProvider(self._resource_group_name, self._region, self._azure_service)
    self._nic_provider = NetworkInterfaceProvider(self._resource_group_name, self._region, self._azure_service)
    self._vm_provider = VirtualMachineProvider(self._resource_group_name, self._region, self._azure_service)
    # Wrap every VM the provider discovered so lookups by name hit the cache.
    for v_m in self._vm_provider.list():
        self._cache[v_m.name] = self._vm_to_instance(v_m)
def azure_service(tmp_path_factory) -> AzureService:  # pylint: disable=no-self-use
    """Return a fake AzureService backed by a temp directory (or the real one for e2e runs)."""
    # Flip to True to test against real Azure. That turns this into an e2e test and
    # takes around 8 minutes (2m provisioning, 6 min cleanup with wait=True).
    use_real_azure = False
    if use_real_azure:
        return AzureService()
    return FakeAzureService(tmp_path_factory.mktemp("azure-provision"))
def get_scylla_images(  # pylint: disable=too-many-branches
        scylla_version: str,
        region_name: str,
        arch: VmArch = VmArch.X86) -> list[GalleryImageVersion]:
    """Find Scylla images in the SCYLLA-IMAGES resource group matching a version spec.

    :param scylla_version: plain version ("4.5.0"), a git commit id, or a branched
        spec "branch:latest" / "branch:all" / "branch:<build-id>"
    :param region_name: only images located in this region are returned
    :param arch: CPU architecture tag to match
    :returns: matching images sorted by their 'build_id' tag; for "branch:latest"
        only the newest one.
    """
    version_bucket = scylla_version.split(":", 1)
    only_latest = False
    tags_to_search = {'arch': arch.value}
    if len(version_bucket) == 1:
        if '.' in scylla_version:
            # Plain version, like 4.5.0
            tags_to_search['ScyllaVersion'] = lambda ver: ver and ver.startswith(scylla_version)
        else:
            # Commit id, e.g. d28c3ee75183a6de3e9b474127b8c0b4d01bbac2
            tags_to_search['scylla-git-commit'] = scylla_version
    else:
        # Branched version, like master:latest
        branch, build_id = version_bucket
        tags_to_search['branch'] = branch
        if build_id == 'latest':
            only_latest = True
        elif build_id == 'all':
            pass
        else:
            tags_to_search['build-id'] = build_id
    output = []
    with suppress(AzureResourceNotFoundError):
        gallery_image_versions = AzureService().compute.images.list_by_resource_group(
            resource_group_name="SCYLLA-IMAGES",
        )
        for image in gallery_image_versions:
            # Filter by region
            if image.location != region_name:
                continue
            # Guard: skip tag filtering gracefully when an image carries no tags at all.
            image_tags = image.tags or {}
            # Filter by tags
            for tag_name, expected_value in tags_to_search.items():
                actual_value = image_tags.get(tag_name)
                if callable(expected_value):
                    if not expected_value(actual_value):
                        break
                elif expected_value != actual_value:
                    break
            else:
                output.append(image)
    # NOTE(review): filtering uses the 'build-id' tag (hyphen) while sorting uses
    # 'build_id' (underscore) -- presumably images carry both tags; confirm.
    # Default missing tags to '' so sorted() never compares None (which raises TypeError).
    output = sorted(output, key=lambda img: (img.tags or {}).get('build_id') or '')
    if only_latest:
        return output[-1:]
    return output
class SubnetProvider:
    _resource_group_name: str
    _azure_service: AzureService = AzureService()
    _cache: Dict[str, Subnet] = field(default_factory=dict)

    def __post_init__(self):
        """Prime the cache with subnets already present in the resource group."""
        try:
            for vnet in self._azure_service.network.virtual_networks.list(self._resource_group_name):
                for subnet in self._azure_service.network.subnets.list(self._resource_group_name, vnet.name):
                    self._cache[f"{vnet.name}-{subnet.name}"] = subnet
        except ResourceNotFoundError:
            pass

    def get_or_create(self, vnet_name: str, network_sec_group_id: str,
                      subnet_name: str = "default") -> Subnet:
        """Return the cached subnet or provision a 10.0.0.0/24 one bound to the security group."""
        cache_name = f"{vnet_name}-{subnet_name}"
        cached = self._cache.get(cache_name)
        if cached is not None:
            return cached
        LOGGER.info("Creating subnet in resource group %s...", self._resource_group_name)
        self._azure_service.network.subnets.begin_create_or_update(
            resource_group_name=self._resource_group_name,
            virtual_network_name=vnet_name,
            subnet_name=subnet_name,
            subnet_parameters={
                "address_prefix": "10.0.0.0/24",
                "network_security_group": {
                    "id": network_sec_group_id,
                },
            },
        ).wait()
        subnet = self._azure_service.network.subnets.get(self._resource_group_name, vnet_name, subnet_name)
        LOGGER.info("Provisioned subnet %s for the %s vnet", subnet.name, vnet_name)
        self._cache[cache_name] = subnet
        return subnet

    def clear_cache(self):
        """Forget all cached subnets."""
        self._cache = {}
class NetworkSecurityGroupProvider:
    _resource_group_name: str
    _region: str
    _azure_service: AzureService = AzureService()
    _cache: Dict[str, NetworkSecurityGroup] = field(default_factory=dict)

    def __post_init__(self):
        """Prime the cache with security groups already present in the resource group."""
        try:
            for group in self._azure_service.network.network_security_groups.list(self._resource_group_name):
                self._cache[group.name] = self._azure_service.network.network_security_groups.get(
                    self._resource_group_name, group.name)
        except ResourceNotFoundError:
            pass

    def get_or_create(self, security_rules: Iterable, name="default") -> NetworkSecurityGroup:
        """Creates or gets (if already exists) security group"""
        cached = self._cache.get(name)
        if cached is not None:
            return cached
        rules_payload = rules_to_payload(security_rules)
        LOGGER.info("Creating SCT network security group in resource group %s...",
                    self._resource_group_name)
        self._azure_service.network.network_security_groups.begin_create_or_update(
            resource_group_name=self._resource_group_name,
            network_security_group_name=name,
            parameters={
                "location": self._region,
                "security_rules": rules_payload,
            },
        ).wait()
        sec_group = self._azure_service.network.network_security_groups.get(
            self._resource_group_name, name)
        LOGGER.info("Provisioned security group %s in the %s resource group",
                    sec_group.name, self._resource_group_name)
        self._cache[name] = sec_group
        return sec_group

    def clear_cache(self):
        """Forget all cached security groups."""
        self._cache = {}
class VirtualNetworkProvider:
    _resource_group_name: str
    _region: str
    _azure_service: AzureService = AzureService()
    _cache: Dict[str, VirtualNetwork] = field(default_factory=dict)

    def __post_init__(self):
        """Prime the cache with vnets already present in the resource group."""
        try:
            for listed in self._azure_service.network.virtual_networks.list(self._resource_group_name):
                vnet = self._azure_service.network.virtual_networks.get(
                    self._resource_group_name, listed.name)
                self._cache[vnet.name] = vnet
        except ResourceNotFoundError:
            pass

    def get_or_create(self, name: str = "default") -> VirtualNetwork:
        """Return the cached vnet or provision a new one spanning 10.0.0.0/16."""
        cached = self._cache.get(name)
        if cached is not None:
            return cached
        LOGGER.info("Creating vnet in resource group %s...", self._resource_group_name)
        self._azure_service.network.virtual_networks.begin_create_or_update(
            resource_group_name=self._resource_group_name,
            virtual_network_name=name,
            parameters={
                "location": self._region,
                "address_space": {
                    "address_prefixes": ["10.0.0.0/16"],
                }
            }).wait()
        vnet = self._azure_service.network.virtual_networks.get(self._resource_group_name, name)
        LOGGER.info("Provisioned vnet %s in the %s resource group",
                    vnet.name, self._resource_group_name)
        self._cache[vnet.name] = vnet
        return vnet

    def clear_cache(self):
        """Forget all cached vnets."""
        self._cache = {}
class ResourceGroupProvider:
    """Class for providing resource groups and taking care about discovery existing ones."""
    _name: str
    _region: str
    _azure_service: AzureService = AzureService()
    _cache: Optional[ResourceGroup] = field(default=None)

    def __post_init__(self):
        """Look up an already-existing resource group for this provider and cache it."""
        try:
            resource_group = self._azure_service.resource.resource_groups.get(self._name)
            assert resource_group.location == self._region, \
                f"resource group {resource_group.name} does not belong to {self._region} region (location)"
            self._cache = resource_group
        except ResourceNotFoundError:
            pass

    def get_or_create(self) -> ResourceGroup:
        """Return the cached resource group, creating it on first use."""
        if self._cache is not None:
            LOGGER.debug("Found resource group: %s in cache", self._name)
            return self._cache
        LOGGER.info("Creating %s SCT resource group in region %s...", self._name, self._region)
        created = self._azure_service.resource.resource_groups.create_or_update(
            resource_group_name=self._name,
            parameters={
                "location": self._region
            },
        )
        LOGGER.info("Provisioned resource group %s in the %s region",
                    created.name, created.location)
        self._cache = created
        return created

    def delete(self, wait: bool = False):
        """Deletes resource group along with all contained resources."""
        LOGGER.info("Initiating cleanup of resource group: %s...", self._name)
        task = self._azure_service.resource.resource_groups.begin_delete(self._name)
        LOGGER.info("Cleanup initiated")
        self._cache = None
        if wait is True:
            LOGGER.info("Waiting for cleanup completion")
            task.wait()
def discover_regions(cls, test_id: str = "",
                     azure_service: AzureService = AzureService(),
                     **kwargs) -> List["AzureProvisioner"]:  # pylint: disable=arguments-differ,unused-argument
    """Discover provisioners in each region for the given test id.

    When test_id is empty, discover every SCT-related provisioner instead."""
    sct_groups = [rg for rg in azure_service.resource.resource_groups.list()
                  if rg.name.startswith("SCT-")]
    provisioners = []
    for resource_group in sct_groups:
        if test_id:
            if test_id in resource_group.name:
                provisioners.append(cls(test_id, resource_group.location, azure_service))
        else:
            # rg name format is SCT-<test_id>-<region>: recover test_id from the middle.
            extracted_id = "-".join(resource_group.name.split("-")[1:-1])
            if extracted_id:
                provisioners.append(cls(extracted_id, resource_group.location, azure_service))
    return provisioners
def azure_service(self) -> AzureService:  # pylint: disable=no-self-use; pylint doesn't know about cached_property
    """Lazily build the AzureService client used by this object."""
    return AzureService()
class IpAddressProvider:
    _resource_group_name: str
    _region: str
    _azure_service: AzureService = AzureService()
    # Azure resource name ("<name>-<version>") -> PublicIPAddress.
    _cache: Dict[str, PublicIPAddress] = field(default_factory=dict)

    def __post_init__(self):
        """Discover existing ip addresses for resource group."""
        try:
            ips = self._azure_service.network.public_ip_addresses.list(
                self._resource_group_name)
            for ip in ips:
                # Re-fetch individually to get the full resource view.
                ip = self._azure_service.network.public_ip_addresses.get(
                    self._resource_group_name, ip.name)
                self._cache[ip.name] = ip
        except ResourceNotFoundError:
            pass

    def get_or_create(self, names: List[str] = "default",
                      version: str = "IPV4") -> List[PublicIPAddress]:
        """Get cached or provision static Standard-SKU public ips for the given base names.

        Accepts a list of names (or a single name string); returns addresses in the
        same order. All creation requests are sent first, then awaited.
        """
        if isinstance(names, str):
            # BUGFIX: the default value is the bare string "default"; without this
            # normalization the loop below iterates it character by character,
            # creating bogus ips named "d-ipv4", "e-ipv4", ...
            names = [names]
        addresses = []
        pollers = []
        for name in names:
            ip_name = self._get_ip_name(name, version)
            if ip_name in self._cache:
                addresses.append(self._cache[ip_name])
                continue
            LOGGER.info("Creating public_ip %s in resource group %s...", ip_name,
                        self._resource_group_name)
            poller = self._azure_service.network.public_ip_addresses.begin_create_or_update(
                resource_group_name=self._resource_group_name,
                public_ip_address_name=ip_name,
                parameters={
                    "location": self._region,
                    "sku": {
                        "name": "Standard",
                    },
                    "public_ip_allocation_method": "Static",
                    "public_ip_address_version": version.upper(),
                },
            )
            pollers.append((ip_name, poller))
        for ip_name, poller in pollers:
            poller.wait()
            # need to get it separately as seems not always it gets created even if
            # result() returns proper ip_address.
            address = self._azure_service.network.public_ip_addresses.get(
                self._resource_group_name, ip_name)
            LOGGER.info("Provisioned public ip %s (%s) in the %s resource group",
                        address.name, address.ip_address, self._resource_group_name)
            self._cache[ip_name] = address
            addresses.append(address)
        return addresses

    def get(self, name: str = "default", version: str = "IPV4"):
        """Return the cached address for name/version; raises KeyError when unknown."""
        ip_name = self._get_ip_name(name, version)
        return self._cache[ip_name]

    def delete(self, ip_address: PublicIPAddress):
        # just remove from cache as it should be deleted along with network interface
        del self._cache[ip_address.name]

    def clear_cache(self):
        """Forget all cached public ips."""
        self._cache = {}

    @staticmethod
    def _get_ip_name(name: str, version: str):
        """Compose the Azure resource name, e.g. ("x", "IPV4") -> "x-ipv4"."""
        return f"{name}-{version.lower()}"
class VirtualMachineProvider:
    # Dataclass-style fields: shared resource group/region plus the Azure client.
    _resource_group_name: str
    _region: str
    _azure_service: AzureService = AzureService()
    # VM name -> last fetched VirtualMachine object.
    _cache: Dict[str, VirtualMachine] = field(default_factory=dict)

    def __post_init__(self):
        """Discover existing virtual machines for resource group."""
        try:
            v_ms = self._azure_service.compute.virtual_machines.list(
                self._resource_group_name)
            for v_m in v_ms:
                # Re-fetch each VM individually to get the full resource view.
                v_m = self._azure_service.compute.virtual_machines.get(
                    self._resource_group_name, v_m.name)
                self._cache[v_m.name] = v_m
        except ResourceNotFoundError:
            # Resource group does not exist yet -- start with an empty cache.
            pass

    def get_or_create(self, definitions: List[InstanceDefinition], nics_ids: List[str],
                      pricing_model: PricingModel) -> List[VirtualMachine]:  # pylint: disable=too-many-locals
        """Create (or return cached) VMs for the given definitions, one nic id per definition.

        All creation requests are sent first, then awaited; if any request or wait
        fails, remaining pollers are still processed and the last error encountered
        is re-raised wrapped in ProvisionError.
        """
        v_ms = []
        pollers = []
        error_to_raise = None
        for definition, nic_id in zip(definitions, nics_ids):
            if definition.name in self._cache:
                v_ms.append(self._cache[definition.name])
                continue
            LOGGER.info("Creating '%s' VM in resource group %s...", definition.name,
                        self._resource_group_name)
            LOGGER.info("Instance params: %s", definition)
            params = {
                "location": self._region,
                "tags": definition.tags | {
                    "ssh_user": definition.user_name,
                    "ssh_key": definition.ssh_key.name
                },
                "hardware_profile": {
                    "vm_size": definition.type,
                },
                "network_profile": {
                    "network_interfaces": [{
                        "id": nic_id,
                        "properties": {
                            "deleteOption": "Delete"
                        }
                    }],
                },
            }
            if definition.user_data is None:
                # in case we use specialized image, we don't change things like computer_name, usernames, ssh_keys
                os_profile = {}
            else:
                builder = UserDataBuilder(user_data_objects=definition.user_data)
                custom_data = builder.build_user_data_yaml()
                os_profile = self._get_os_profile(
                    computer_name=definition.name,
                    admin_username=definition.user_name,
                    # Throwaway random password; password auth is disabled in the os profile.
                    admin_password=binascii.hexlify(os.urandom(20)).decode(),
                    ssh_public_key=definition.ssh_key.public_key.decode(),
                    custom_data=custom_data)
                params.update({
                    "user_data": base64.b64encode(
                        builder.get_scylla_machine_image_json().encode(
                            'utf-8')).decode('latin-1')
                })
            storage_profile = self._get_scylla_storage_profile(
                image_id=definition.image_id,
                name=definition.name,
                disk_size=definition.root_disk_size)
            params.update(os_profile)
            params.update(storage_profile)
            params.update(self._get_pricing_params(pricing_model))
            try:
                poller = self._azure_service.compute.virtual_machines.begin_create_or_update(
                    resource_group_name=self._resource_group_name,
                    vm_name=definition.name,
                    parameters=params)
                pollers.append((definition, poller))
            except AzureError as err:
                LOGGER.error(
                    "Error when sending create vm request for VM %s: %s",
                    definition.name, str(err))
                error_to_raise = err
        for definition, poller in pollers:
            try:
                poller.wait()
                v_m = self._azure_service.compute.virtual_machines.get(
                    self._resource_group_name, definition.name)
                LOGGER.info("Provisioned VM %s in the %s resource group",
                            v_m.name, self._resource_group_name)
                self._cache[v_m.name] = v_m
                v_ms.append(v_m)
            except AzureError as err:
                LOGGER.error("Error when waiting for VM %s: %s",
                             definition.name, str(err))
                error_to_raise = err
        if error_to_raise:
            raise ProvisionError(error_to_raise)
        return v_ms

    def list(self):
        """Return all VMs currently known to the cache."""
        return list(self._cache.values())

    def delete(self, name: str, wait: bool = True):
        """Delete the named VM (marking its os disk for deletion with it) and evict it from cache."""
        LOGGER.info("Triggering termination of instance: %s", name)
        # Flag the os disk to be deleted together with the VM before issuing the delete.
        # NOTE(review): the begin_update poller is not awaited before begin_delete --
        # confirm the update is applied in time for the delete to honor it.
        self._azure_service.compute.virtual_machines.begin_update(
            self._resource_group_name,
            vm_name=name,
            parameters={
                "storageProfile": {
                    "osDisk": {
                        "createOption": "FromImage",
                        "deleteOption": "Delete"
                    }
                }
            })
        task = self._azure_service.compute.virtual_machines.begin_delete(
            self._resource_group_name, vm_name=name)
        if wait is True:
            LOGGER.info("Waiting for termination of instance: %s...", name)
            task.wait()
            LOGGER.info("Instance %s has been terminated.", name)
        del self._cache[name]

    def reboot(self, name: str, wait: bool = True) -> None:
        """Restart the named VM, optionally blocking until the restart completes."""
        LOGGER.info("Triggering reboot of instance: %s", name)
        task = self._azure_service.compute.virtual_machines.begin_restart(
            self._resource_group_name, vm_name=name)
        if wait is True:
            LOGGER.info("Waiting for reboot of instance: %s...", name)
            task.wait()
            LOGGER.info("Instance %s has been rebooted.", name)

    def add_tags(self, name: str, tags: Dict[str, str]) -> VirtualMachine:
        """Adds tags to instance (with waiting for completion)"""
        if name not in self._cache:
            raise AttributeError(
                f"Instance '{name}' does not exist in resource group '{self._resource_group_name}'"
            )
        # Merge new tags into the currently known set and push the result.
        current_tags = self._cache[name].tags
        current_tags.update(tags)
        self._azure_service.compute.virtual_machines.begin_update(
            self._resource_group_name, name, parameters={
                "tags": current_tags
            }).wait()
        # Refresh the cached object so it reflects the updated tags.
        v_m = self._azure_service.compute.virtual_machines.get(
            self._resource_group_name, name)
        self._cache[v_m.name] = v_m
        return v_m

    @staticmethod
    def _get_os_profile(computer_name: str, admin_username: str,
                        admin_password: str, ssh_public_key: str,
                        custom_data: str) -> Dict[str, Any]:
        """Build the os_profile section: user, authorized ssh key, base64 custom data;
        password authentication is disabled."""
        os_profile = {
            "os_profile": {
                "computer_name": computer_name,
                "admin_username": admin_username,
                "admin_password": admin_password,
                "custom_data": base64.b64encode(
                    custom_data.encode('utf-8')).decode('latin-1'),
                "linux_configuration": {
                    "disable_password_authentication": True,
                    "ssh": {
                        "public_keys": [{
                            "path": f"/home/{admin_username}/.ssh/authorized_keys",
                            "key_data": ssh_public_key,
                        }],
                    },
                },
            }
        }
        return os_profile

    @staticmethod
    def _get_scylla_storage_profile(
            image_id: str, name: str,
            disk_size: Optional[int] = None) -> Dict[str, Any]:
        """Creates storage profile based on image_id. image_id may refer to scylla-crafted images
        (starting with '/subscription') or to 'Urn' of image
        (see output of e.g. `az vm image list --output table`)"""
        storage_profile = {
            "storage_profile": {
                "os_disk": {
                    "name": f"{name}-os-disk",
                    "os_type": "linux",
                    "caching": "ReadWrite",
                    "create_option": "FromImage",
                    "deleteOption": "Delete",  # somehow deletion of VM does not delete os_disk anyway...
                    "managed_disk": {
                        "storage_account_type": "Premium_LRS",  # SSD
                    }
                } | ({} if disk_size is None else {
                    "disk_size_gb": disk_size
                }),
            }
        }
        # NOTE(review): both update() calls below replace the entire "storage_profile"
        # value, discarding the os_disk settings built above -- confirm this is intended.
        if image_id.startswith("/subscriptions/"):
            storage_profile.update({
                "storage_profile": {
                    "image_reference": {
                        "id": image_id
                    },
                    "deleteOption": "Delete"
                }
            })
        else:
            # Urn form: publisher:offer:sku:version
            image_reference_values = image_id.split(":")
            storage_profile.update({
                "storage_profile": {
                    "image_reference": {
                        "publisher": image_reference_values[0],
                        "offer": image_reference_values[1],
                        "sku": image_reference_values[2],
                        "version": image_reference_values[3],
                    },
                }
            })
        return storage_profile

    @staticmethod
    def _get_pricing_params(pricing_model: PricingModel):
        """Map the pricing model to Azure priority/eviction parameters."""
        if pricing_model != PricingModel.ON_DEMAND:
            return {
                "priority": "Spot",  # possible values are "Regular", "Low", or "Spot"
                "eviction_policy": "Delete",  # can be "Deallocate" or "Delete", Deallocate leaves disks intact
                "billing_profile": {
                    "max_price": -1,  # -1 indicates the VM shouldn't be evicted for price reasons
                }
            }
        else:
            return {"priority": "Regular"}

    def clear_cache(self):
        """Forget all known VMs."""
        self._cache = {}
class NetworkInterfaceProvider:
    _resource_group_name: str
    _region: str
    _azure_service: AzureService = AzureService()
    _cache: Dict[str, NetworkInterface] = field(default_factory=dict)

    def __post_init__(self):
        """Prime the cache with nics already present in the resource group."""
        try:
            for listed in self._azure_service.network.network_interfaces.list(self._resource_group_name):
                nic = self._azure_service.network.network_interfaces.get(
                    self._resource_group_name, listed.name)
                self._cache[nic.name] = nic
        except ResourceNotFoundError:
            pass

    def get(self, name: str) -> NetworkInterface:
        """Return the cached nic for the given base name; raises KeyError when unknown."""
        return self._cache[self.get_nic_name(name)]

    def get_or_create(self, subnet_id: str, ip_addresses_ids: List[str],
                      names: List[str]) -> List[NetworkInterface]:
        """Creates or gets (if already exists) network interface"""
        nics = []
        pending = []
        for name, address in zip(names, ip_addresses_ids):
            nic_name = self.get_nic_name(name)
            if nic_name in self._cache:
                nics.append(self._cache[nic_name])
                continue
            parameters = {
                "location": self._region,
                "ip_configurations": [{
                    "name": nic_name,
                    "subnet": {
                        "id": subnet_id,
                    },
                    "public_ip_address": {
                        "id": address,
                        "properties": {
                            "deleteOption": "Delete"
                        }
                    },
                }],
                "enable_accelerated_networking": True,
            }
            LOGGER.info("Creating nic in resource group %s...", self._resource_group_name)
            poller = self._azure_service.network.network_interfaces.begin_create_or_update(
                resource_group_name=self._resource_group_name,
                network_interface_name=nic_name,
                parameters=parameters,
            )
            pending.append((nic_name, poller))
        for nic_name, poller in pending:
            poller.wait()
            nic = self._azure_service.network.network_interfaces.get(
                self._resource_group_name, nic_name)
            LOGGER.info("Provisioned nic %s in the %s resource group",
                        nic.name, self._resource_group_name)
            self._cache[nic_name] = nic
            nics.append(nic)
        return nics

    def delete(self, nic: NetworkInterface):
        # Only evict from cache -- the Azure resource itself is configured to be
        # deleted along with its owner (deleteOption above).
        del self._cache[nic.name]

    def clear_cache(self):
        """Forget all cached nics."""
        self._cache = {}

    @staticmethod
    def get_nic_name(name: str):
        """Compose the Azure nic resource name from a base name."""
        return f"{name}-nic"