def validate_num_classes(vector_file: Union[str, Path], num_classes: int, attribute_name: str, ignore_index: int):
    """Check that `num_classes` equals the number of classes detected in the specified
    attribute for each GeoPackage.

    FIXME: this validation **will not succeed** if a Geopackage contains only a
    subset of `num_classes` (e.g. 3 of 4).

    Args:
        vector_file: full file path of the vector image
        num_classes: number of classes set in config_template.yaml
        attribute_name: name of the value field representing the required classes in the vector image file
        ignore_index: (int) target value that is ignored during training and does not contribute to the input gradient

    Return:
        Set of unique attribute values found in the gpkg vector file

    Raises:
        ValueError: if the number of detected classes differs from `num_classes`.
    """
    distinct_att = set()
    with fiona.open(vector_file, 'r') as src:
        for feature in tqdm(src, leave=False, position=1, desc='Scanning features'):
            # Use property of set to store unique values
            distinct_att.add(get_key_recursive(attribute_name, feature))
    # The ignore_index ("dontcare") value does not count as a class for validation purposes.
    detected_classes = len(distinct_att) - 1 if ignore_index in distinct_att else len(distinct_att)
    if detected_classes != num_classes:
        raise ValueError(
            'The number of classes in the yaml.config {} is different than the number of classes in '
            'the file {} {}'.format(num_classes, vector_file, str(list(distinct_att))))
    return distinct_att
def validate_num_classes(vector_file: Union[str, Path], num_classes: int, attribute_name: str, ignore_index: int,
                         target_ids: List):
    """Check that `num_classes` equals the number of classes detected in the specified
    attribute for each GeoPackage.

    FIXME: this validation **will not succeed** if a Geopackage contains only a
    subset of `num_classes` (e.g. 3 of 4).

    Args:
        vector_file: full file path of the vector image
        num_classes: number of classes set in config_template.yaml
        attribute_name: name of the value field representing the required classes in the vector image file
        ignore_index: (int) target value that is ignored during training and does not contribute to the input gradient
        target_ids: list of identifiers to burn from the vector file (None = use all)

    Return:
        Set of unique attribute values found in the gpkg vector file

    Raises:
        FileNotFoundError: if `vector_file` does not exist.
        ValueError: if `target_ids` length mismatches `num_classes`, or if more classes
            than `num_classes` are found when no `target_ids` is given.
    """
    if isinstance(vector_file, str):
        vector_file = Path(vector_file)
    if not vector_file.is_file():
        raise FileNotFoundError(f"Could not locate gpkg file at {vector_file}")
    unique_att_vals = set()
    with fiona.open(vector_file, 'r') as src:
        for feature in tqdm(src, leave=False, position=1, desc='Scanning features'):
            # Use property of set to store unique values
            unique_att_vals.add(int(get_key_recursive(attribute_name, feature)))
    # if dontcare value is defined, remove from the set of unique attribute values for verification purposes
    # (discard is a no-op when the value is absent, so no membership pre-check is needed)
    unique_att_vals.discard(ignore_index)

    # if burning a subset of gpkg's classes
    if target_ids:
        if len(target_ids) != num_classes:
            raise ValueError(
                f'Yaml parameters mismatch. \n'
                f'Got target_ids {target_ids} (sample sect) with length {len(target_ids)}. '
                f'Expected match with num_classes {num_classes} (global sect)'
            )
        # make sure target ids are a subset of all attribute values in gpkg
        if not set(target_ids).issubset(unique_att_vals):
            logging.warning(
                f'\nFailed scan of vector file: {vector_file}\n'
                f'\tExpected to find all target ids {target_ids}.\n'
                f'\tFound {unique_att_vals} for attribute "{attribute_name}"')
    else:
        # this can happen if gpkg doesn't contain all classes, thus the warning rather than exception
        if len(unique_att_vals) < num_classes:
            logging.warning(
                f'Found {str(list(unique_att_vals))} classes in file {vector_file}. Expected {num_classes}'
            )
        # this should not happen, thus the exception raised
        elif len(unique_att_vals) > num_classes:
            raise ValueError(
                f'Found {str(list(unique_att_vals))} classes in file {vector_file}. Expected {num_classes}'
            )
    return unique_att_vals