コード例 #1
0
ファイル: utils_test.py プロジェクト: difince/pipelines
    def test_get_skip_evaluation_pipeline_and_parameters(self):
        """Verify the parameter values built for the skip-evaluation pipeline.

        ``utils.get_skip_evaluation_pipeline_and_parameters`` returns a pair;
        only its second element (the parameter dict) is checked here, against
        an expected dict assembled section by section below.
        """
        transformations = {'auto': {'column_name': 'feature_1'}}
        split_spec = {
            'fraction_split': {
                'training_fraction': 0.8,
                'validation_fraction': 0.2,
                'test_fraction': 0.0,
            }
        }
        data_source = {
            'csv_data_source': {
                'csv_filenames': ['gs://foo/bar.csv'],
            }
        }
        # NOTE(review): the trailing 1000 is presumably a training budget;
        # confirm against the helper's signature.
        _pipeline, actual_parameters = (
            utils.get_skip_evaluation_pipeline_and_parameters(
                'project', 'us-central1', 'gs://foo', 'target',
                'classification', 'maximize-au-prc', transformations,
                split_spec, data_source, 1000))

        # Scalar arguments are expected back unchanged.
        passthrough = {
            'project': 'project',
            'location': 'us-central1',
            'root_dir': 'gs://foo',
            'target_column_name': 'target',
            'prediction_type': 'classification',
            'optimization_objective': 'maximize-au-prc',
        }
        # The dict arguments come back as JSON-like text in which every double
        # quote is preceded by a literal backslash.
        serialized_specs = {
            'transformations':
                '{\\"auto\\": {\\"column_name\\": \\"feature_1\\"}}',
            'split_spec':
                '{\\"fraction_split\\": {\\"training_fraction\\": 0.8, '
                '\\"validation_fraction\\": 0.2, \\"test_fraction\\": 0.0}}',
            'data_source':
                '{\\"csv_data_source\\": {\\"csv_filenames\\": '
                '[\\"gs://foo/bar.csv\\"]}}',
        }
        # Stage settings this test never passes in, so the helper derives
        # them. The two deadlines add up to roughly 1.0 hour, presumably a
        # split of the 1000 argument above -- TODO confirm.
        stage_settings = {
            'stage_1_deadline_hours': 0.7708333333333334,
            'stage_1_num_parallel_trials': 35,
            'stage_1_num_selected_trials': 7,
            'stage_1_single_run_max_secs': 634,
            'reduce_search_space_mode': 'minimal',
            'stage_2_deadline_hours': 0.22916666666666663,
            'stage_2_num_parallel_trials': 35,
            'stage_2_num_selected_trials': 5,
            'stage_2_single_run_max_secs': 634,
        }
        # Optional settings not passed by this test; presumably the helper's
        # defaults.
        optional_settings = {
            'weight_column_name': '',
            'optimization_objective_recall_value': -1,
            'optimization_objective_precision_value': -1,
            'study_spec_override': '',
            'stage_1_tuner_worker_pool_specs_override': '',
            'cv_trainer_worker_pool_specs_override': '',
            'export_additional_model_without_custom_ops': False,
            'encryption_spec_key_name': '',
        }
        # Dataflow settings, likewise not passed by this test.
        dataflow_settings = {
            'stats_and_example_gen_dataflow_machine_type': 'n1-standard-16',
            'stats_and_example_gen_dataflow_max_num_workers': 25,
            'stats_and_example_gen_dataflow_disk_size_gb': 40,
            'transform_dataflow_machine_type': 'n1-standard-16',
            'transform_dataflow_max_num_workers': 25,
            'transform_dataflow_disk_size_gb': 40,
            'dataflow_subnetwork': '',
            'dataflow_use_public_ips': True,
        }
        # The sections share no keys, so merging them yields one flat dict.
        expected_parameters = {
            **passthrough,
            **serialized_specs,
            **stage_settings,
            **optional_settings,
            **dataflow_settings,
        }
        self.assertEqual(actual_parameters, expected_parameters)
コード例 #2
0
ファイル: utils_test.py プロジェクト: kubeflow/pipelines
    def test_get_skip_evaluation_pipeline_and_parameters(self):
        """Verify skip-evaluation parameter values against the shared fixture.

        Only the second element of the helper's returned pair is compared.
        The expected dict is ``self.parameter_values``, which is defined
        outside this method (presumably in ``setUp`` -- confirm).
        """
        transformations = {'auto': {'column_name': 'feature_1'}}
        split_spec = {
            'fraction_split': {
                'training_fraction': 0.8,
                'validation_fraction': 0.2,
                'test_fraction': 0.0,
            }
        }
        data_source = {
            'csv_data_source': {
                'csv_filenames': ['gs://foo/bar.csv'],
            }
        }
        # Positional order must match the helper's signature exactly.
        call_args = (
            'project',
            'us-central1',
            'gs://foo',
            'target',
            'classification',
            'maximize-au-prc',
            transformations,
            split_spec,
            data_source,
            1000,
        )
        _pipeline, actual_parameters = (
            utils.get_skip_evaluation_pipeline_and_parameters(*call_args))
        self.assertEqual(actual_parameters, self.parameter_values)