Example #1
def search_solutions_request(test_paths, specified_template=None):
    user_agent = "test_agent"
    version = core_pb2.DESCRIPTOR.GetOptions().Extensions[
        core_pb2.protocol_version]

    time_bound = 0.5
    priority = 10
    # allowed_value_types = [value_pb2.ValueType.Value(value) for value in ALLOWED_VALUE_TYPES]

    problem_description = utils.encode_problem_description(
        problem_module.Problem.load(test_paths['TRAIN']['problem']))

    template = None
    if specified_template == 'FULL':
        with d3m_utils.silence():
            pipeline = pipeline_utils.load_pipeline(
                FULL_SPECIFIED_PIPELINE_PATH)
        template = utils.encode_pipeline_description(
            pipeline, ALLOWED_VALUE_TYPES, constants.Path.TEMP_STORAGE_ROOT)
    elif specified_template == 'PRE':  # PRE for PREPROCESSING
        pipeline = runtime_module.get_pipeline(PRE_SPECIFIED_PIPELINE_PATH,
                                               load_all_primitives=False)
        template = utils.encode_pipeline_description(
            pipeline, ALLOWED_VALUE_TYPES, constants.Path.TEMP_STORAGE_ROOT)

    inputs = [value_pb2.Value(dataset_uri=test_paths['TRAIN']['dataset'])]

    request = core_pb2.SearchSolutionsRequest(
        user_agent=user_agent,
        version=version,
        time_bound_search=time_bound,
        priority=priority,
        allowed_value_types=ALLOWED_VALUE_TYPES,
        problem=problem_description,
        template=template,
        inputs=inputs)
    return request
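The request built above targets the D3M TA2/TA3 Core gRPC service. A minimal sketch of sending it, assuming the generated gRPC stubs are importable as core_pb2_grpc next to core_pb2 and that a TA2 server listens on localhost:45042 (the module name and address are assumptions, not part of the original):

import grpc
import core_pb2_grpc  # assumed to be generated alongside core_pb2

def send_search_solutions(test_paths):
    # Hypothetical endpoint; the original code does not say where the TA2 server runs.
    channel = grpc.insecure_channel('localhost:45042')
    stub = core_pb2_grpc.CoreStub(channel)
    request = search_solutions_request(test_paths)
    # SearchSolutions returns a response whose search_id identifies the started search.
    response = stub.SearchSolutions(request)
    return response.search_id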
Example #2
File: schemas.py  Project: zwbjtu123/tods
def load_default_pipeline():
    from axolotl.utils import pipeline as pipeline_utils
    pipeline = pipeline_utils.load_pipeline(DEFAULT_PIPELINE_DIR)
    return pipeline
Example #3
def get_classification_pipeline():
    with open(schemas_utils.PIPELINES_DB_DIR) as file:
        default_pipelines = json.load(file)

    return pipeline_utils.load_pipeline(default_pipelines['CLASSIFICATION'][0])
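The pipelines database read here is assumed to be a JSON file mapping task keywords (such as 'CLASSIFICATION') to lists of pipeline descriptions. Under that assumption, the same calls generalize to a small lookup helper; the function name and the task_type parameter are illustrative, not part of the original:

import json
from axolotl.utils import pipeline as pipeline_utils, schemas as schemas_utils  # assumed import paths

def get_default_pipeline_for(task_type):
    # task_type is a key in the pipelines DB, e.g. 'CLASSIFICATION'; available keys depend on the DB contents.
    with open(schemas_utils.PIPELINES_DB_DIR) as file:
        default_pipelines = json.load(file)
    # Load the first pipeline description registered for the requested task type.
    return pipeline_utils.load_pipeline(default_pipelines[task_type][0])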
Example #4
def test_fit_lr(self):
    pipeline_info = os.path.join(os.path.dirname(__file__), 'resources', 'logistic_regeression.json')
    pipeline = pipeline_utils.load_pipeline(pipeline_info)
    _, pipeline_result = self.tuner_base.search_fit(input_data=[self.dataset], time_limit=60,
                                                    pipeline_candidates=[pipeline])
    self.assertEqual(pipeline_result.error, None)
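For context, self.tuner_base and self.dataset come from the test fixture and are not shown; input_data for search_fit is assumed to be a list of d3m container Datasets. A hedged sketch of producing such a dataset from a local datasetDoc.json (the path is hypothetical):

import os
from d3m.container.dataset import Dataset

def load_test_dataset():
    # Hypothetical resource location; the original test builds self.dataset in its fixture setup.
    dataset_doc = os.path.join(os.path.dirname(__file__), 'resources', 'dataset', 'datasetDoc.json')
    return Dataset.load('file://' + os.path.abspath(dataset_doc))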