def check_details_validity(mode: str):
    """Decorator factory: validate the target cluster's details before running the wrapped command."""
    def decorator(func):
        def with_checker(*args, **kwargs):
            # Get params
            cluster_name = kwargs['cluster_name']

            # Get details
            try:
                cluster_details = load_cluster_details(cluster_name=cluster_name)
            except FileNotFoundError:
                raise CliException(f"Cluster {cluster_name} is not found")

            # Check details validity
            try:
                if mode == 'grass' and cluster_details['mode'] == 'grass':
                    if cluster_details['cloud']['infra'] != 'azure':
                        raise ParsingError(f"Details are broken: Invalid infra: {cluster_details['cloud']['infra']}")
                elif mode == 'k8s' and cluster_details['mode'] == 'k8s':
                    if cluster_details['cloud']['infra'] != 'azure':
                        raise ParsingError(f"Details are broken: Invalid infra: {cluster_details['cloud']['infra']}")
                else:
                    raise ParsingError(f"Details are broken: Invalid mode: {cluster_details['mode']}")
            except KeyError as e:
                raise ParsingError(f"Details are broken: Missing key: '{e.args[0]}'")

            return func(*args, **kwargs)

        return with_checker

    return decorator

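# A minimal usage sketch of the decorator above. The command function below is
# hypothetical; the only real requirement is that callers pass cluster_name as
# a keyword argument, because with_checker reads kwargs['cluster_name'].
@check_details_validity(mode='grass')
def describe_cluster(cluster_name: str, **kwargs):
    logger.info(f"Cluster {cluster_name} passed the validity check")

# describe_cluster(cluster_name="my-grass-cluster")
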
def scale_node(self, replicas: int, node_size: str):
    # Load details
    nodes_details = self.grass_executor.remote_get_nodes_details()

    # Count existing nodes by size
    node_size_to_count = collections.defaultdict(int)
    for node_details in nodes_details.values():
        node_size_to_count[node_details['node_size']] += 1

    # Get node_size_to_spec, and check if node_size is valid
    node_size_to_spec = self._get_node_size_to_spec()
    if node_size not in node_size_to_spec:
        raise CliException(f"Invalid node_size: {node_size}")

    # Scale nodes
    if node_size_to_count[node_size] > replicas:
        self._delete_nodes(
            num=node_size_to_count[node_size] - replicas,
            node_size=node_size
        )
    elif node_size_to_count[node_size] < replicas:
        self._create_nodes(
            num=replicas - node_size_to_count[node_size],
            node_size=node_size,
            node_size_to_spec=node_size_to_spec
        )
    else:
        logger.warning_yellow("Replica count already matches, nothing to create or delete")

def status(self, resource_name: str):
    if resource_name == "master":
        return_status = self.grass_executor.remote_get_master_details()
    elif resource_name == "nodes":
        return_status = self.grass_executor.remote_get_nodes_details()
    else:
        raise CliException(f"Resource {resource_name} is unsupported")

    # Print status
    logger.info(json.dumps(return_status, indent=4, sort_keys=True))

def push_image(self, image_name: str, image_path: str, remote_context_path: str, remote_image_name: str):
    # Load details
    cluster_details = self.cluster_details
    admin_username = cluster_details['user']['admin_username']
    master_public_ip_address = cluster_details['master']['public_ip_address']

    # Get images dir
    images_dir = f"{GlobalPaths.MARO_CLUSTERS}/{self.cluster_name}/images"

    # Push image
    if image_name:
        new_file_name = get_valid_file_name(image_name)
        local_image_path = f"{images_dir}/{new_file_name}"
        self._save_image(
            image_name=image_name,
            export_path=os.path.expanduser(local_image_path)
        )
        if self._check_checksum_validity(
            local_file_path=os.path.expanduser(local_image_path),
            remote_file_path=os.path.join(images_dir, new_file_name)
        ):
            logger.info_green(f"The image file '{new_file_name}' already exists")
            return
        copy_files_to_node(
            local_path=local_image_path,
            remote_dir=images_dir,
            admin_username=admin_username,
            node_ip_address=master_public_ip_address
        )
        self.grass_executor.remote_update_image_files_details()
        self._batch_load_images()
        logger.info_green(f"Image {image_name} is loaded")
    elif image_path:
        file_name = os.path.basename(image_path)
        new_file_name = get_valid_file_name(file_name)
        local_image_path = f"{images_dir}/{new_file_name}"
        # Copy the user-supplied image file into the images dir before pushing;
        # keep the source path distinct from the local copy's path.
        copy_and_rename(
            source_path=os.path.expanduser(image_path),
            target_dir=os.path.expanduser(images_dir),
            new_name=new_file_name
        )
        if self._check_checksum_validity(
            local_file_path=os.path.expanduser(local_image_path),
            remote_file_path=os.path.join(images_dir, new_file_name)
        ):
            logger.info_green(f"The image file '{new_file_name}' already exists")
            return
        copy_files_to_node(
            local_path=local_image_path,
            remote_dir=images_dir,
            admin_username=admin_username,
            node_ip_address=master_public_ip_address
        )
        self.grass_executor.remote_update_image_files_details()
        self._batch_load_images()
    elif remote_context_path and remote_image_name:
        self.grass_executor.remote_build_image(
            remote_context_path=remote_context_path,
            remote_image_name=remote_image_name
        )
        self._batch_load_images()
    else:
        raise CliException("Invalid arguments")

def validate_and_fill_dict(template_dict: dict, actual_dict: dict, optional_key_to_value: dict):
    deep_diff = DeepDiff(template_dict, actual_dict).to_dict()
    missing_keys = deep_diff.get('dictionary_item_removed', [])
    for key in missing_keys:
        if key not in optional_key_to_value:
            raise CliException(f"Invalid deployment: key {key} not found")
        set_in_dict(actual_dict, get_map_list(deep_diff_str=key), optional_key_to_value[key])

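# A minimal usage sketch (hypothetical deployment dicts). DeepDiff reports keys
# present in template_dict but missing from actual_dict as strings such as
# "root['cloud']['region']"; get_map_list is assumed to turn that string into
# the path ['cloud', 'region'] that set_in_dict expects.
template = {'name': '', 'cloud': {'region': ''}}
actual = {'name': 'my-cluster', 'cloud': {}}
validate_and_fill_dict(
    template_dict=template,
    actual_dict=actual,
    optional_key_to_value={"root['cloud']['region']": 'eastus'}
)
# actual is now {'name': 'my-cluster', 'cloud': {'region': 'eastus'}}
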
def build_cluster_details(create_deployment: dict):
    # Validate the deployment and fill in optional values
    K8sAzureExecutor._validate_create_deployment(create_deployment=create_deployment)

    # Get cluster name and save details
    cluster_name = create_deployment['name']
    if os.path.isdir(os.path.expanduser(f"{GlobalPaths.MARO_CLUSTERS}/{cluster_name}")):
        raise CliException(f"Cluster {cluster_name} already exists")
    os.makedirs(os.path.expanduser(f"{GlobalPaths.MARO_CLUSTERS}/{cluster_name}"))
    save_cluster_details(
        cluster_name=cluster_name,
        cluster_details=create_deployment
    )

def build_cluster_details(create_deployment: dict):
    # Standardize create deployment
    GrassAzureExecutor._standardize_create_deployment(create_deployment=create_deployment)

    # Get cluster name and save details
    cluster_name = create_deployment['name']
    if os.path.isdir(os.path.expanduser(f"{GlobalPaths.MARO_CLUSTERS}/{cluster_name}")):
        raise CliException(f"Cluster {cluster_name} already exists")
    os.makedirs(os.path.expanduser(f"{GlobalPaths.MARO_CLUSTERS}/{cluster_name}"))
    save_cluster_details(
        cluster_name=cluster_name,
        cluster_details=create_deployment
    )

def pull_data(cluster_name: str, local_path: str, remote_path: str, **kwargs):
    # Load details
    cluster_details = load_cluster_details(cluster_name=cluster_name)
    admin_username = cluster_details['user']['admin_username']
    master_public_ip_address = cluster_details['master']['public_ip_address']

    # Copy data from the master node; the remote path must be absolute
    if not remote_path.startswith("/"):
        raise CliException("Invalid remote path: it must start with '/'")
    copy_files_from_node(
        local_dir=local_path,
        remote_path=f"{GlobalPaths.MARO_CLUSTERS}/{cluster_name}/data{remote_path}",
        admin_username=admin_username,
        node_ip_address=master_public_ip_address
    )

def retry_until_connected(self, node_ip_address: str) -> bool:
    remain_retries = 10
    while remain_retries > 0:
        try:
            self.test_connection(node_ip_address)
            return True
        except CliException:
            remain_retries -= 1
            logger.debug(f"Unable to connect to {node_ip_address}, {remain_retries} retries remaining")
            time.sleep(10)
    raise CliException(f"Unable to connect to {node_ip_address}")

def scale_node(self, replicas: int, node_size: str):
    # Get node_size_to_info
    node_size_to_info = self._get_node_size_to_info()

    # Get node_size_to_spec, and check if node_size is valid
    node_size_to_spec = self._get_node_size_to_spec()
    if node_size not in node_size_to_spec:
        raise CliException(f"Invalid node_size: {node_size}")

    # Scale node
    if node_size not in node_size_to_info:
        self._build_node_pool(
            replicas=replicas,
            node_size=node_size
        )
    elif node_size_to_info[node_size]['count'] != replicas:
        self._scale_node_pool(
            replicas=replicas,
            node_size=node_size,
            node_size_to_info=node_size_to_info
        )
    else:
        logger.warning_yellow("Replica count already matches, nothing to create or delete")

def push_data(self, local_path: str, remote_dir: str):
    # Load details
    cluster_details = self.cluster_details
    cluster_id = cluster_details['id']

    # Get sas
    sas = self._check_and_get_account_sas()

    # Push data to the cluster's Azure file share; the remote dir must be absolute
    source_path = get_reformatted_source_path(local_path)
    target_dir = get_reformatted_target_dir(remote_dir)
    if not target_dir.startswith("/"):
        raise CliException("Invalid remote path: it must start with '/'")
    copy_command = f'azcopy copy ' \
                   f'"{source_path}" ' \
                   f'"https://{cluster_id}st.file.core.windows.net/{cluster_id}-fs{target_dir}?{sas}" ' \
                   f'--recursive=True'
    _ = SubProcess.run(copy_command)

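# For illustration: with a hypothetical cluster_id "abc123", local_path "./data"
# and remote_dir "/train", the generated command has the shape (SAS token elided):
#   azcopy copy "./data" "https://abc123st.file.core.windows.net/abc123-fs/train?<sas>" --recursive=True
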
def start_node(self, replicas: int, node_size: str):
    # Get nodes details
    nodes_details = self.grass_executor.remote_get_nodes_details()

    # Get startable nodes
    startable_nodes = []
    for node_name, node_details in nodes_details.items():
        if node_details['node_size'] == node_size and node_details['state'] == 'Stopped':
            startable_nodes.append(node_name)

    # Check replicas
    if len(startable_nodes) < replicas:
        raise CliException(f"Not enough {node_size} nodes can be started")

    # Parallel start
    params = [[startable_node] for startable_node in startable_nodes[:replicas]]
    with ThreadPool(GlobalParams.PARALLELS) as pool:
        pool.starmap(self._start_node, params)

def copy_and_rename(source_path: str, target_dir: str, new_name: str = None):
    """Copy a file into a directory and optionally rename it.

    Args:
        source_path (str): Path of the source file.
        target_dir (str): Directory the file is copied into.
        new_name (str): Name of the new file; if None, the original name is kept.
    """
    source_path = os.path.expanduser(source_path)
    target_dir = os.path.expanduser(target_dir)
    if os.path.isdir(source_path):
        raise CliException("Invalid file path: cannot be a folder")
    shutil.copy2(source_path, target_dir)
    if new_name is not None:
        old_name = os.path.basename(source_path)
        old_target_path = os.path.join(target_dir, old_name)
        new_target_path = os.path.join(target_dir, new_name)
        os.rename(old_target_path, new_target_path)

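# Usage sketch (hypothetical paths): copy ~/images/maro.tar into a cluster's
# images dir under a new name.
# copy_and_rename(
#     source_path="~/images/maro.tar",
#     target_dir="~/.maro/clusters/my-cluster/images",
#     new_name="maro_latest.tar"
# )
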
def pull_data(self, local_dir: str, remote_path: str):
    # Load details
    cluster_details = self.cluster_details
    cluster_id = cluster_details['id']

    # Get sas
    sas = self._check_and_get_account_sas()

    # Pull data from the cluster's Azure file share; the remote path must be absolute
    local_dir = os.path.expanduser(local_dir)
    source_path = get_reformatted_source_path(remote_path)
    target_dir = get_reformatted_target_dir(local_dir)
    mkdir_script = f"mkdir -p {target_dir}"
    _ = SubProcess.run(mkdir_script)
    if not source_path.startswith("/"):
        raise CliException("Invalid remote path: it must start with '/'")
    copy_command = f'azcopy copy ' \
                   f'"https://{cluster_id}st.file.core.windows.net/{cluster_id}-fs{source_path}?{sas}" ' \
                   f'"{target_dir}" ' \
                   f'--recursive=True'
    _ = SubProcess.run(copy_command)