Example #1
    def test_prod_staging_same(self) -> None:
        staging_yaml = YAMLDict.from_path(self.path_for_build_file("staging.yaml"))
        prod_yaml = YAMLDict.from_path(self.path_for_build_file("prod.yaml"))

        diff = deepdiff.DeepDiff(staging_yaml.get(), prod_yaml.get())

        # We expect the RECIDIVIZ_ENV values to be different
        env_diff = diff["values_changed"].pop("root['env_variables']['RECIDIVIZ_ENV']")
        self.assertEqual({"new_value": "production", "old_value": "staging"}, env_diff)

        # We expect the Cloud SQL instance names to be different, but the names should follow the same pattern
        cloud_sql_instance_diff = diff["values_changed"].pop(
            "root['beta_settings']['cloud_sql_instances']"
        )
        staging_cloud_sql_instances: str = cloud_sql_instance_diff["old_value"]
        prod_cloud_sql_instances = cloud_sql_instance_diff["new_value"]
        self.assertEqual(
            (
                staging_cloud_sql_instances.replace(
                    "recidiviz-staging", "recidiviz-123"
                )  # Staging project becomes production
                .replace("dev-", "prod-")  # Dev prefix becomes prod
                .replace("-0af0a", "")  # Development case triage suffix is dropped
            ),
            prod_cloud_sql_instances,
        )

        # There should be no other values changed between the two
        self.assertFalse(diff.pop("values_changed"))
        # Aside from the few values changed, there should be no other changes
        self.assertFalse(diff)
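The pop-then-assert-empty pattern above generalizes to any pair of config dicts: whitelist the differences you expect, then assert that nothing else changed. A minimal, self-contained sketch with hypothetical keys (not the real app.yaml contents):

from deepdiff import DeepDiff

staging_config = {"env_variables": {"RECIDIVIZ_ENV": "staging"}, "runtime": "python38"}
prod_config = {"env_variables": {"RECIDIVIZ_ENV": "production"}, "runtime": "python38"}

diff = DeepDiff(staging_config, prod_config)

# Pop the single difference we expect, then verify nothing else differs.
env_diff = diff["values_changed"].pop("root['env_variables']['RECIDIVIZ_ENV']")
assert env_diff == {"new_value": "production", "old_value": "staging"}
assert not diff["values_changed"]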
    def _get_raw_data_file_configs(
            self) -> Dict[str, DirectIngestRawFileConfig]:
        """Returns list of file tags we expect to see on raw files for this region."""
        if os.path.isdir(self.yaml_config_file_dir):
            default_filename = f"{self.region_code}_default.yaml"
            default_file_path = os.path.join(self.yaml_config_file_dir,
                                             default_filename)
            if not os.path.exists(default_file_path):
                raise ValueError(
                    f"Missing default raw data configs for region: {self.region_code}"
                )

            default_contents = YAMLDict.from_path(default_file_path)
            default_encoding = default_contents.pop("default_encoding", str)
            default_separator = default_contents.pop("default_separator", str)

            raw_data_configs = {}
            for filename in os.listdir(self.yaml_config_file_dir):
                if filename == default_filename:
                    continue
                yaml_file_path = os.path.join(self.yaml_config_file_dir,
                                              filename)
                if os.path.isdir(yaml_file_path):
                    continue

                yaml_contents = YAMLDict.from_path(yaml_file_path)

                file_tag = yaml_contents.pop("file_tag", str)
                if not file_tag:
                    raise ValueError(f"Missing file_tag in [{yaml_file_path}]")
                if filename != f"{self.region_code.lower()}_{file_tag}.yaml":
                    raise ValueError(
                        f"Mismatched file_tag [{file_tag}] and filename [{filename}]"
                        f" in [{yaml_file_path}]")
                if file_tag in raw_data_configs:
                    raise ValueError(
                        f"Found file tag [{file_tag}] in [{yaml_file_path}]"
                        f" that is already defined in another yaml file.")

                raw_data_configs[
                    file_tag] = DirectIngestRawFileConfig.from_yaml_dict(
                        self.region_code,
                        file_tag,
                        yaml_file_path,
                        default_encoding,
                        default_separator,
                        yaml_contents,
                        filename,
                    )
        else:
            raise ValueError(
                f"Missing raw data configs for region: {self.region_code}")
        return raw_data_configs
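The loop above enforces a naming convention: one <region>_default.yaml per region plus one <region>_<file_tag>.yaml per raw table, where the filename must embed the file_tag declared inside it. A toy sketch of that check with hypothetical names:

# Hypothetical raw data config directory for a region "US_XX" (illustrative only):
#   us_xx_default.yaml      -> default_encoding, default_separator
#   us_xx_tblInmate.yaml    -> file_tag: tblInmate
#   us_xx_tblCharge.yaml    -> file_tag: tblCharge
region_code = "US_XX"
file_tag = "tblInmate"
filename = "us_xx_tblInmate.yaml"

# A config file is rejected if its name does not encode its own file_tag.
assert filename == f"{region_code.lower()}_{file_tag}.yaml"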
    def _get_user_inputs(
        initialization_params: YAMLDict, model_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Helper to retrieve user_inputs for get_model_params"""

        user_inputs: Dict[str, Any] = dict()
        user_inputs_yaml_dict = initialization_params.pop_dict("user_inputs")

        if "big_query_simulation_tag" in model_params["data_inputs_raw"].keys():
            user_inputs = {
                "policy_year": user_inputs_yaml_dict.pop("policy_year", float),
                "start_year": user_inputs_yaml_dict.pop("start_year", float),
                "projection_years": user_inputs_yaml_dict.pop(
                    "projection_years", float
                ),
            }

        if "big_query_inputs" in model_params["data_inputs_raw"].keys():
            user_inputs = {
                "start_year": user_inputs_yaml_dict.pop("start_year", float),
                "projection_years": user_inputs_yaml_dict.pop(
                    "projection_years", float
                ),
                "run_date": user_inputs_yaml_dict.pop("run_date", str),
            }

        # Check for optional arguments
        if user_inputs_yaml_dict:
            user_inputs_keys = user_inputs_yaml_dict.keys()
            for k in user_inputs_keys:
                if k not in {"constant_admissions", "speed_run"}:
                    raise ValueError(f"Received unexpected key in user_inputs: {k}")
                user_inputs[k] = user_inputs_yaml_dict.pop(k, bool)

        return user_inputs
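For reference, a hypothetical user_inputs block for the big_query_inputs case, shown as the dict the YAML parses to, along with the optional-key rule the helper applies (values are illustrative):

user_inputs_block = {
    "start_year": 2021,
    "projection_years": 5,
    "run_date": "2021-06-01",
    "speed_run": True,  # optional; only constant_admissions / speed_run are accepted extras
}

required_keys = {"start_year", "projection_years", "run_date"}
optional_keys = {"constant_admissions", "speed_run"}
extra_keys = set(user_inputs_block) - required_keys
assert extra_keys <= optional_keys, f"Received unexpected key in user_inputs: {extra_keys - optional_keys}"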
    def _get_valid_data_inputs(
            initialization_params: YAMLDict) -> Dict[str, Any]:
        """Helper to retrieve data_inputs for get_model_params"""

        given_data_inputs = initialization_params.pop_dict("data_inputs")
        if len(given_data_inputs) != 1:
            raise ValueError(
                f"Only one data input can be set in the yaml file, not {len(given_data_inputs)}"
            )

        data_inputs: Dict[str, Any] = dict()

        if "big_query_inputs" in given_data_inputs.keys():
            big_query_inputs_yaml_dict = given_data_inputs.pop_dict(
                "big_query_inputs")
            big_query_inputs_keys = big_query_inputs_yaml_dict.keys()

            big_query_inputs_dict: Dict[str, str] = dict()
            for k in big_query_inputs_keys:
                big_query_inputs_dict[k] = big_query_inputs_yaml_dict.pop(
                    k, str)

            data_inputs["big_query_inputs"] = big_query_inputs_dict
        elif "big_query_simulation_tag" in given_data_inputs.keys():
            data_inputs["big_query_simulation_tag"] = given_data_inputs.pop(
                "big_query_simulation_tag", str)
        else:
            raise ValueError(
                f"Received unexpected key in data_inputs: {given_data_inputs.keys()[0]}"
            )

        return data_inputs
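The data_inputs block must contain exactly one of the two supported forms. A toy illustration of that mutual-exclusivity rule on plain dicts (table names are hypothetical):

data_inputs_by_tag = {"big_query_simulation_tag": "baseline_2021"}
data_inputs_by_table = {
    "big_query_inputs": {
        "admissions_data": "my-project.spark.admissions",
        "transitions_data": "my-project.spark.transitions",
    }
}

for given in (data_inputs_by_tag, data_inputs_by_table):
    assert len(given) == 1  # exactly one top-level key is allowed
    assert set(given) <= {"big_query_inputs", "big_query_simulation_tag"}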
    def get_model_params(cls, yaml_file_path: str) -> Dict[str, Any]:
        """Get the model parameters from the YAMLDict"""
        initialization_params = YAMLDict.from_path(yaml_file_path)

        cls._check_valid_yaml_inputs(initialization_params)

        model_params: Dict[str, Any] = dict()

        model_params["reference_year"] = initialization_params.pop(
            "reference_date", float)

        model_params["time_step"] = initialization_params.pop(
            "time_step", float)

        model_params["disaggregation_axes"] = initialization_params.pop(
            "disaggregation_axes", list)

        model_params["data_inputs_raw"] = cls._get_valid_data_inputs(
            initialization_params)

        model_params["user_inputs_raw"] = cls._get_user_inputs(
            initialization_params, model_params)

        (
            model_params["compartments_architecture"],
            model_params["compartment_costs"],
        ) = cls._get_valid_compartments(initialization_params)

        return model_params
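As a rough sketch, the top-level YAML that get_model_params consumes and the keys it produces look something like the following (field values and compartment names are illustrative, not taken from a real config):

# What YAMLDict.from_path(yaml_file_path) would wrap, roughly:
model_yaml = {
    "reference_date": 2020.0,
    "time_step": 1 / 12,
    "disaggregation_axes": ["crime_type"],
    "data_inputs": {"big_query_simulation_tag": "baseline_2021"},
    "user_inputs": {"policy_year": 2022, "start_year": 2020, "projection_years": 5},
    "compartments_architecture": {"pretrial": "shell", "prison": "full"},
    "per_year_costs": {"prison": 35000.0},
}

# Keys of the dict get_model_params returns (note reference_date -> reference_year):
expected_model_param_keys = {
    "reference_year",
    "time_step",
    "disaggregation_axes",
    "data_inputs_raw",
    "user_inputs_raw",
    "compartments_architecture",
    "compartment_costs",
}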
    def _get_valid_compartments(
        initialization_params: YAMLDict,
    ) -> Tuple[Dict[str, str], Dict[str, float]]:
        """Helper to retrieve model_architecture and compartment costs for get_model_params"""

        compartments_architecture_yaml_key = "compartments_architecture"
        compartments_architecture_raw = initialization_params.pop_dict(
            compartments_architecture_yaml_key
        )
        compartments_architecture_keys = compartments_architecture_raw.keys()

        compartments_architecture_dict: Dict[str, Any] = dict()
        for k in compartments_architecture_keys:
            compartments_architecture_dict[
                k
            ] = compartments_architecture_raw.pop_optional(k, str)

        compartment_costs_key = "per_year_costs"
        compartment_costs_raw = initialization_params.pop_dict(compartment_costs_key)
        compartment_costs_keys = compartment_costs_raw.keys()

        compartment_costs_dict: Dict[str, float] = dict()
        for k in compartment_costs_keys:
            compartment_costs_dict[k] = compartment_costs_raw.pop(k, float)

        # Ensure there are compartment costs for every compartment in the model architecture
        model_compartments = set(
            c
            for c in compartments_architecture_keys
            if compartments_architecture_dict[c] != "shell"
        )
        compartment_costs = set(compartment_costs_keys)
        if compartment_costs != model_compartments:
            raise ValueError(
                f"Compartments do not match in the YAML '{compartment_costs_key}' "
                f"and '{compartments_architecture_yaml_key}'\n"
                f"Mismatched values: {compartment_costs ^ model_compartments}"
            )

        return compartments_architecture_dict, compartment_costs_dict
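A toy version of the final consistency check: every non-"shell" compartment needs a per-year cost, and there must be no extra costs (compartment names are illustrative).

architecture = {"pretrial": "shell", "prison": "full", "parole": "full"}
per_year_costs = {"prison": 35000.0, "parole": 5000.0}

model_compartments = {c for c, kind in architecture.items() if kind != "shell"}
if set(per_year_costs) != model_compartments:
    raise ValueError(f"Mismatched values: {set(per_year_costs) ^ model_compartments}")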
Example #7
def _pipeline_regions_by_job_name() -> Dict[str, str]:
    """Parses the production_calculation_pipeline_templates.yaml config file to determine
    which region a pipeline should be run in."""
    daily_pipelines = YAMLDict.from_path(PRODUCTION_TEMPLATES_PATH).pop_dicts(
        "daily_pipelines"
    )
    historical_pipelines = YAMLDict.from_path(PRODUCTION_TEMPLATES_PATH).pop_dicts(
        "historical_pipelines"
    )

    pipeline_regions = {
        pipeline.pop("job_name", str): pipeline.pop("region", str)
        for pipeline in daily_pipelines
    }

    pipeline_regions.update(
        {
            pipeline.pop("job_name", str): pipeline.pop("region", str)
            for pipeline in historical_pipelines
        }
    )

    return pipeline_regions
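The resulting mapping is keyed by job_name. A minimal sketch of the same dict-comprehension pattern over hypothetical pipeline entries:

# Entries as they might look after pop_dicts (job names and regions are made up).
daily_pipelines = [
    {"job_name": "us-xx-incarceration-calculations-36", "region": "us-east1"},
    {"job_name": "us-yy-supervision-calculations-24", "region": "us-west1"},
]
historical_pipelines = [
    {"job_name": "us-xx-historical-incarceration-240", "region": "us-east1"},
]

pipeline_regions = {p["job_name"]: p["region"] for p in daily_pipelines}
pipeline_regions.update({p["job_name"]: p["region"] for p in historical_pipelines})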
def _get_month_range_for_metric_and_state() -> Dict[str, Dict[str, int]]:
    """Determines the maximum number of months that each metric is calculated regularly
    for each state.

    Returns a dictionary in the format: {
        metric_table: {
                        state_code: int,
                        state_code: int
                      }
        }
    where the int values are the number of months for which the metric is regularly
    calculated for that state.
    """
    # Map metric type enum values to the corresponding tables in BigQuery
    metric_type_to_table: Dict[str, str] = {
        metric_type.value: table
        for table, metric_type in
        dataflow_config.DATAFLOW_TABLES_TO_METRIC_TYPES.items()
    }

    all_pipelines = YAMLDict.from_path(
        dataflow_config.PRODUCTION_TEMPLATES_PATH)
    daily_pipelines = all_pipelines.pop_dicts("daily_pipelines")
    historical_pipelines = all_pipelines.pop_dicts("historical_pipelines")

    # Dict with the format: {metric_table: {state_code: int}}
    month_range_for_metric_and_state: Dict[str, Dict[str, int]] = defaultdict(
        lambda: defaultdict(int))

    for pipeline_config_group in [daily_pipelines, historical_pipelines]:
        for pipeline_config in pipeline_config_group:
            if (pipeline_config.pop("pipeline", str)
                    in dataflow_config.ALWAYS_UNBOUNDED_DATE_PIPELINES):
                # This pipeline is always run in full, and is handled separately
                continue

            metrics = pipeline_config.pop("metric_types", str)
            calculation_month_count = pipeline_config.pop(
                "calculation_month_count", int)
            state_code = pipeline_config.pop("state_code", str)

            for metric in metrics.split(" "):
                metric_table = metric_type_to_table[metric]
                current_max = month_range_for_metric_and_state[metric_table][
                    state_code]
                month_range_for_metric_and_state[metric_table][
                    state_code] = max(current_max, calculation_month_count)

    return month_range_for_metric_and_state
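A stripped-down sketch of the aggregation above: a nested defaultdict that keeps the largest calculation_month_count seen for each (metric table, state) pair. Table names here are hypothetical.

from collections import defaultdict

# Keep the largest month count seen for each (metric table, state) pair.
month_range = defaultdict(lambda: defaultdict(int))

pipeline_rows = [
    {"metric_table": "some_metric_table", "state_code": "US_XX", "months": 36},
    {"metric_table": "some_metric_table", "state_code": "US_XX", "months": 240},
]
for row in pipeline_rows:
    current_max = month_range[row["metric_table"]][row["state_code"]]
    month_range[row["metric_table"]][row["state_code"]] = max(current_max, row["months"])

assert month_range["some_metric_table"]["US_XX"] == 240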
Example #9
    def default_config(self) -> DirectIngestRawFileDefaultConfig:
        default_filename = f"{self.region_code.lower()}_default.yaml"
        default_file_path = os.path.join(self.yaml_config_file_dir,
                                         default_filename)
        if not os.path.exists(default_file_path):
            raise ValueError(
                f"Missing default raw data configs for region: {self.region_code}. "
                f"None found at path: [{default_file_path}]")
        default_contents = YAMLDict.from_path(default_file_path)
        default_encoding = default_contents.pop("default_encoding", str)
        default_separator = default_contents.pop("default_separator", str)
        default_ignore_quotes = default_contents.pop("default_ignore_quotes",
                                                     bool)

        return DirectIngestRawFileDefaultConfig(
            filename=default_filename,
            default_encoding=default_encoding,
            default_separator=default_separator,
            default_ignore_quotes=default_ignore_quotes,
        )
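For context, a hypothetical <region>_default.yaml would parse to something like the dict below; the three keys are exactly the ones popped above (values are made up).

# Illustrative contents of us_xx_default.yaml after parsing:
default_yaml_contents = {
    "default_encoding": "UTF-8",
    "default_separator": ",",
    "default_ignore_quotes": False,
}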
    def _check_valid_yaml_inputs(initialization_params: YAMLDict) -> None:
        # Make sure all required inputs are provided in the yaml file and nothing unexpected is present

        required_inputs = {
            "user_inputs",
            "compartments_architecture",
            "reference_date",
            "time_step",
            "data_inputs",
            "disaggregation_axes",
            "per_year_costs",
        }
        given_inputs = set(initialization_params.keys())

        missing_inputs = required_inputs.difference(given_inputs)
        if len(missing_inputs) > 0:
            raise ValueError(f"Missing yaml inputs: {missing_inputs}")

        unexpected_inputs = given_inputs.difference(required_inputs)
        if len(unexpected_inputs) > 0:
            raise ValueError(f"Unexpected yaml inputs: {unexpected_inputs}")
    def from_yaml(cls, yaml_path: str) -> "DatasetSchemaInfo":
        yaml_contents = YAMLDict.from_path(yaml_path)

        dataset = yaml_contents.pop("dataset", str)

        yaml_tables = yaml_contents.pop_dicts("tables")
        tables: List[TableSchemaInfo] = []

        for table_dict in yaml_tables:
            table_name = table_dict.pop("name", str)
            columns = table_dict.pop("columns", list)
            tables.append(
                TableSchemaInfo(
                    table_name=table_name,
                    columns=sorted([c.lower() for c in columns]),
                ))

        return DatasetSchemaInfo(
            dataset=dataset,
            tables=tables,
        )
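A sketch of the schema YAML shape from_yaml expects, together with the lower-case/sort normalization applied to column names (dataset, table, and column names are illustrative):

schema_yaml = {
    "dataset": "my_dataset",
    "tables": [
        {"name": "person", "columns": ["Person_Id", "full_name"]},
        {"name": "charge", "columns": ["charge_id", "person_id"]},
    ],
}

normalized_columns = {
    table["name"]: sorted(col.lower() for col in table["columns"])
    for table in schema_yaml["tables"]
}
assert normalized_columns["person"] == ["full_name", "person_id"]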
Example #12
    def from_file(cls, path: str = PRODUCTS_CONFIG_PATH) -> "ProductConfigs":
        """Reads a product config file and returns a list of corresponding ProductConfig objects."""

        product_config = YAMLDict.from_path(path).pop("products", list)
        products = [
            ProductConfig(
                name=product["name"],
                description=product["description"],
                exports=product["exports"],
                states=[
                    ProductStateConfig(
                        state_code=state["state_code"], environment=state["environment"]
                    )
                    for state in product["states"]
                ]
                if "states" in product
                else None,
                environment=product.get("environment"),
                is_state_agnostic=product.get("is_state_agnostic", False),
            )
            for product in product_config
        ]
        return cls(products=products)
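A hypothetical products config illustrating the keys from_file reads: name, description, and exports are required, each states entry needs state_code and environment, and the product-level environment and is_state_agnostic fields are optional. Values below are placeholders.

# Illustrative structure of the "products" list in the products config file:
products_yaml = {
    "products": [
        {
            "name": "Some Product",
            "description": "What the product is for.",
            "exports": ["EXPORT_A", "EXPORT_B"],
            "states": [{"state_code": "US_XX", "environment": "production"}],
            "is_state_agnostic": False,
        }
    ]
}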
Example #13
def _validate_yaml(yaml_path: str, uploads: List[Dict[str, Any]]) -> None:
    "Validate the contents of the relevant yaml file"

    yaml_dict = YAMLDict.from_path(yaml_path)

    # Check for all required and no extra inputs
    required_inputs = {
        "user_inputs",
        "compartments_architecture",
        "reference_date",
        "time_step",
        "data_inputs",
        "disaggregation_axes",
        "per_year_costs",
    }
    given_inputs = set(yaml_dict.keys())

    missing_inputs = required_inputs.difference(given_inputs)
    if len(missing_inputs) > 0:
        raise ValueError(f"Missing yaml inputs: {missing_inputs}")

    unexpected_inputs = given_inputs.difference(required_inputs)
    if len(unexpected_inputs) > 0:
        raise ValueError(f"Unexpected yaml inputs: {unexpected_inputs}")

    # Check that all disaggregation axes are in all the dataframes
    disaggregation_axes = yaml_dict.pop("disaggregation_axes", list)

    for axis in disaggregation_axes:
        for upload in uploads:
            if upload["table"] != "total_population_data_raw":
                df = upload["data_df"]
                if axis not in df.columns:
                    raise ValueError(
                        f"All disagregation axes must be included in the input dataframe columns\n"
                        f"Expected: {disaggregation_axes}, Actual: {df.columns}"
                    )
Example #14
    def test_travis_yaml_parses(self) -> None:
        yaml_dict = YAMLDict.from_path(self.path_for_build_file(".travis.yml"))
        self.assertTrue(yaml_dict.get())
Example #15
    def test_staging_yaml_parses(self) -> None:
        yaml_dict = YAMLDict.from_path(self.path_for_build_file("staging.yaml"))
        self.assertTrue(yaml_dict.get())
Example #16
def compare_dataflow_output_to_sandbox(
    sandbox_dataset_prefix: str,
    job_name_to_compare: str,
    base_output_job_id: str,
    sandbox_output_job_id: str,
    additional_columns_to_compare: List[str],
    allow_overwrite: bool = False,
) -> None:
    """Compares the output for all metrics produced by the daily pipeline job with the given |job_name_to_compare|
    between the output from the |base_output_job_id| job in the dataflow_metrics dataset and the output from the
    |sandbox_output_job_id| job in the sandbox dataflow dataset."""
    bq_client = BigQueryClientImpl()
    sandbox_dataflow_dataset_id = (sandbox_dataset_prefix + "_" +
                                   DATAFLOW_METRICS_DATASET)

    sandbox_comparison_output_dataset_id = (sandbox_dataset_prefix +
                                            "_dataflow_comparison_output")
    sandbox_comparison_output_dataset_ref = bq_client.dataset_ref_for_id(
        sandbox_comparison_output_dataset_id)

    if bq_client.dataset_exists(sandbox_comparison_output_dataset_ref) and any(
            bq_client.list_tables(sandbox_comparison_output_dataset_id)):
        if not allow_overwrite:
            if __name__ == "__main__":
                logging.error(
                    "Dataset %s already exists in project %s. To overwrite, set --allow_overwrite.",
                    sandbox_comparison_output_dataset_id,
                    bq_client.project_id,
                )
                sys.exit(1)
            else:
                raise ValueError(
                    f"Cannot write comparison output to a non-empty dataset. Please delete tables in dataset: "
                    f"{bq_client.project_id}.{sandbox_comparison_output_dataset_id}."
                )
        else:
            # Clean up the existing tables in the dataset
            for table in bq_client.list_tables(
                    sandbox_comparison_output_dataset_id):
                bq_client.delete_table(table.dataset_id, table.table_id)

    bq_client.create_dataset_if_necessary(
        sandbox_comparison_output_dataset_ref,
        TEMP_DATASET_DEFAULT_TABLE_EXPIRATION_MS)

    query_jobs: List[Tuple[QueryJob, str]] = []

    pipelines = YAMLDict.from_path(PRODUCTION_TEMPLATES_PATH).pop_dicts(
        "daily_pipelines")

    for pipeline in pipelines:
        if pipeline.pop("job_name", str) == job_name_to_compare:
            pipeline_metric_types = pipeline.peek_optional("metric_types", str)

            if not pipeline_metric_types:
                raise ValueError(
                    f"Pipeline job {job_name_to_compare} missing required metric_types attribute."
                )

            metric_types_for_comparison = pipeline_metric_types.split()

            for metric_class, metric_table in DATAFLOW_METRICS_TO_TABLES.items():
                metric_type_value = DATAFLOW_TABLES_TO_METRIC_TYPES[
                    metric_table].value

                if metric_type_value in metric_types_for_comparison:
                    comparison_query = _query_for_metric_comparison(
                        bq_client,
                        base_output_job_id,
                        sandbox_output_job_id,
                        sandbox_dataflow_dataset_id,
                        metric_class,
                        metric_table,
                        additional_columns_to_compare,
                    )

                    query_job = bq_client.create_table_from_query_async(
                        dataset_id=sandbox_comparison_output_dataset_id,
                        table_id=metric_table,
                        query=comparison_query,
                        overwrite=True,
                    )

                    # Add query job to the list of running jobs
                    query_jobs.append((query_job, metric_table))

    for query_job, output_table_id in query_jobs:
        # Wait for the insert job to complete before looking for the table
        query_job.result()

        output_table = bq_client.get_table(
            sandbox_comparison_output_dataset_ref, output_table_id)

        if output_table.num_rows == 0:
            # If there are no rows in the output table, then the output was identical
            bq_client.delete_table(sandbox_comparison_output_dataset_id,
                                   output_table_id)

    metrics_with_different_output = peekable(
        bq_client.list_tables(sandbox_comparison_output_dataset_id))

    logging.info(
        "\n*************** DATAFLOW OUTPUT COMPARISON RESULTS ***************\n"
    )

    if metrics_with_different_output:
        for metric_table in metrics_with_different_output:
            # This will always be true, and is here to silence mypy warnings
            assert isinstance(metric_table, bigquery.table.TableListItem)

            logging.warning(
                "Dataflow output differs for metric %s. See %s.%s for diverging rows.",
                metric_table.table_id,
                sandbox_comparison_output_dataset_id,
                metric_table.table_id,
            )
    else:
        logging.info(
            "Dataflow output identical. Deleting dataset %s.",
            sandbox_comparison_output_dataset_ref.dataset_id,
        )
        bq_client.delete_dataset(sandbox_comparison_output_dataset_ref,
                                 delete_contents=True)
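A hypothetical invocation of the comparison helper, assuming the surrounding module is importable; all argument values below are placeholders, not real Dataflow job identifiers.

# Placeholder job names / IDs -- substitute real Dataflow job identifiers.
compare_dataflow_output_to_sandbox(
    sandbox_dataset_prefix="my_sandbox",
    job_name_to_compare="us-xx-incarceration-calculations-36",
    base_output_job_id="2021-06-01_00_00_00-1111111111111111111",
    sandbox_output_job_id="2021-06-02_00_00_00-2222222222222222222",
    additional_columns_to_compare=["judicial_district"],
    allow_overwrite=False,
)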
    def __init__(
        self,
        products: List[ProductConfig],
        root_calc_docs_dir: str,
    ):
        self.root_calc_docs_dir = root_calc_docs_dir
        self.products = products

        self.states_by_product = self.get_states_by_product()

        # Reverses the states_by_product dictionary
        self.products_by_state: Dict[StateCode, Dict[
            GCPEnvironment,
            List[ProductName]]] = defaultdict(lambda: defaultdict(list))

        for product_name, environments_to_states in self.states_by_product.items():
            for environment, states in environments_to_states.items():
                for state in states:
                    self.products_by_state[state][environment].append(
                        product_name)
        self.dag_walker = BigQueryViewDagWalker(
            _build_views_to_update(
                view_source_table_datasets=VIEW_SOURCE_TABLE_DATASETS,
                candidate_view_builders=DEPLOYED_VIEW_BUILDERS,
                dataset_overrides=None,
                override_should_build_predicate=True,
            ))
        self.prod_templates_yaml = YAMLDict.from_path(
            PRODUCTION_TEMPLATES_PATH)

        self.daily_pipelines = self.prod_templates_yaml.pop_dicts(
            "daily_pipelines")
        self.historical_pipelines = self.prod_templates_yaml.pop_dicts(
            "historical_pipelines")

        self.metric_calculations_by_state = self._get_state_metric_calculations(
            self.daily_pipelines, "daily")
        # combine with the historical pipelines
        for name, metric_info_list in self._get_state_metric_calculations(
                self.historical_pipelines,
                "triggered by code changes").items():
            self.metric_calculations_by_state[name].extend(metric_info_list)

        # Reverse the metric_calculations_by_state dictionary
        self.state_metric_calculations_by_metric: Dict[
            str, List[StateMetricInfo]] = defaultdict(list)
        for state_name, metric_info_list in self.metric_calculations_by_state.items():
            for metric_info in metric_info_list:
                self.state_metric_calculations_by_metric[
                    metric_info.name].append(
                        StateMetricInfo(
                            name=state_name,
                            month_count=metric_info.month_count,
                            frequency=metric_info.frequency,
                        ))

        self.metrics_by_generic_types = self._get_metrics_by_generic_types()

        self.generic_types_by_metric_name = {}
        for generic_type, metric_list in self.metrics_by_generic_types.items():
            for metric in metric_list:
                self.generic_types_by_metric_name[
                    DATAFLOW_METRICS_TO_TABLES[metric]] = generic_type

        def _preprocess_views(
                v: BigQueryView, _parent_results: Dict[BigQueryView,
                                                       None]) -> None:
            dag_key = DagKey(view_address=v.address)
            node = self.dag_walker.nodes_by_key[dag_key]

            # Fills out full child/parent dependencies and tree representations for use
            # in various sections.
            self.dag_walker.populate_node_family_for_node(
                node=node,
                datasets_to_skip={DATAFLOW_METRICS_MATERIALIZED_DATASET}
                | RAW_TABLE_DATASETS,
                custom_node_formatter=self._dependency_tree_formatter_for_gitbook,
                view_source_table_datasets=VIEW_SOURCE_TABLE_DATASETS
                | LATEST_VIEW_DATASETS,
            )

        self.dag_walker.process_dag(_preprocess_views)
        self.all_views_to_document = self._get_all_views_to_document()
Example #18
    def from_yaml_dict(
        cls,
        file_tag: str,
        file_path: str,
        default_encoding: str,
        default_separator: str,
        default_ignore_quotes: bool,
        file_config_dict: YAMLDict,
        yaml_filename: str,
    ) -> "DirectIngestRawFileConfig":
        """Returns a DirectIngestRawFileConfig built from a YAMLDict"""
        primary_key_cols = file_config_dict.pop("primary_key_cols", list)
        file_description = file_config_dict.pop("file_description", str)
        columns = file_config_dict.pop("columns", list)

        column_names = [column["name"] for column in columns]
        if len(column_names) != len(set(column_names)):
            raise ValueError(
                f"Found duplicate columns in raw_file [{file_tag}]")

        missing_columns = set(primary_key_cols) - {
            column["name"]
            for column in columns
        }
        if missing_columns:
            raise ValueError(
                f"Column(s) marked as primary keys not listed in"
                f" columns list for file [{yaml_filename}]: {missing_columns}")

        supplemental_order_by_clause = file_config_dict.pop_optional(
            "supplemental_order_by_clause", str)
        encoding = file_config_dict.pop_optional("encoding", str)
        separator = file_config_dict.pop_optional("separator", str)
        ignore_quotes = file_config_dict.pop_optional("ignore_quotes", bool)
        custom_line_terminator = file_config_dict.pop_optional(
            "custom_line_terminator", str)
        always_historical_export = file_config_dict.pop_optional(
            "always_historical_export", bool)

        if len(file_config_dict) > 0:
            raise ValueError(f"Found unexpected config values for raw file"
                             f"[{file_tag}]: {repr(file_config_dict.get())}")
        return DirectIngestRawFileConfig(
            file_tag=file_tag,
            file_path=file_path,
            file_description=file_description,
            primary_key_cols=primary_key_cols,
            columns=[
                RawTableColumnInfo(
                    name=column["name"],
                    is_datetime=column.get("is_datetime", False),
                    description=column.get("description", None),
                    known_values=[
                        ColumnEnumValueInfo(
                            value=str(x["value"]),
                            description=x.get("description", None),
                        ) for x in column["known_values"]
                    ] if "known_values" in column else None,
                ) for column in columns
            ],
            supplemental_order_by_clause=supplemental_order_by_clause
            if supplemental_order_by_clause else "",
            encoding=encoding if encoding else default_encoding,
            separator=separator if separator else default_separator,
            custom_line_terminator=custom_line_terminator,
            ignore_quotes=ignore_quotes
            if ignore_quotes else default_ignore_quotes,
            always_historical_export=always_historical_export
            if always_historical_export else False,
        )
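A toy illustration of the per-file override vs. region default fallback used in the constructor call above (values are illustrative):

default_encoding, default_separator, default_ignore_quotes = "UTF-8", ",", False

# Values popped from one file's yaml; None means the key was not set there.
encoding, separator, ignore_quotes = None, "|", None

resolved_encoding = encoding if encoding else default_encoding                      # "UTF-8"
resolved_separator = separator if separator else default_separator                  # "|"
resolved_ignore_quotes = ignore_quotes if ignore_quotes else default_ignore_quotes  # False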
    def from_yaml_dict(
        cls,
        region_code: str,
        file_tag: str,
        file_path: str,
        default_encoding: str,
        default_separator: str,
        file_config_dict: YAMLDict,
        yaml_filename: str,
    ) -> "DirectIngestRawFileConfig":
        """Returns a DirectIngestRawFileConfig built from a YAMLDict"""
        primary_key_cols = file_config_dict.pop("primary_key_cols", list)
        # TODO(#5399): Migrate raw file configs for all legacy regions to have file descriptions
        if region_code.upper() in {"US_PA"}:
            file_description = (file_config_dict.pop_optional(
                "file_description", str) or "LEGACY_FILE_MISSING_DESCRIPTION")
        else:
            file_description = file_config_dict.pop("file_description", str)
        # TODO(#5399): Migrate raw file configs for all legacy regions to have column descriptions
        if region_code.upper() in {"US_PA"}:
            columns = file_config_dict.pop_optional("columns", list) or []
        else:
            columns = file_config_dict.pop("columns", list)

        column_names = [column["name"] for column in columns]
        if len(column_names) != len(set(column_names)):
            raise ValueError(
                f"Found duplicate columns in raw_file [{file_tag}]")

        missing_columns = set(primary_key_cols) - {
            column["name"]
            for column in columns
        }
        # TODO(#5399): Remove exempted region codes once legacy primary keys are documented
        if missing_columns and region_code.upper() not in {"US_PA"}:
            raise ValueError(
                f"Column(s) marked as primary keys not listed in"
                f" columns list for file [{yaml_filename}]: {missing_columns}")

        supplemental_order_by_clause = file_config_dict.pop_optional(
            "supplemental_order_by_clause", str)
        encoding = file_config_dict.pop_optional("encoding", str)
        separator = file_config_dict.pop_optional("separator", str)
        ignore_quotes = file_config_dict.pop_optional("ignore_quotes", bool)
        always_historical_export = file_config_dict.pop_optional(
            "always_historical_export", bool)

        if len(file_config_dict) > 0:
            raise ValueError(f"Found unexpected config values for raw file"
                             f"[{file_tag}]: {repr(file_config_dict.get())}")

        return DirectIngestRawFileConfig(
            file_tag=file_tag,
            file_path=file_path,
            file_description=file_description,
            primary_key_cols=primary_key_cols,
            columns=[
                RawTableColumnInfo(
                    name=column["name"],
                    is_datetime=column.get("is_datetime", False),
                    description=column.get("description", None),
                ) for column in columns
            ],
            supplemental_order_by_clause=supplemental_order_by_clause
            if supplemental_order_by_clause else "",
            encoding=encoding if encoding else default_encoding,
            separator=separator if separator else default_separator,
            ignore_quotes=ignore_quotes if ignore_quotes else False,
            always_historical_export=always_historical_export
            if always_historical_export else False,
        )
Example #20
    def test_circleci_yaml_parses(self) -> None:
        yaml_dict = YAMLDict.from_path(self.path_for_build_file(".circleci/config.yml"))
        self.assertTrue(yaml_dict.get())