Ejemplo n.º 1
0
def read_component_model(component_model_key, party_model_id, model_version):
    """Load the stored protobuf buffers that belong to one component.

    Opens the table named ``model_version`` in namespace ``party_model_id``
    (without creating it) and walks all of its entries. Each storage key is
    cut at its first ``':'``: the part before it is compared with
    ``component_model_key`` and the remainder becomes the buffer name.

    Args:
        component_model_key: Key prefix identifying the component's entries.
        party_model_id: Namespace of the model table.
        model_version: Name of the model table.

    Returns:
        A dict mapping buffer name to the parsed protobuf object. It is empty
        when the table is falsy or has no entry for the component.

    Raises:
        Exception: If no protobuf class can be resolved for a matching entry.
    """
    table = storage.table(
        name=model_version,
        namespace=party_model_id,
        partition=get_model_table_partition_count(),
        create_if_missing=False,
        error_if_exist=False)
    buffers = {}
    if not table:
        return buffers

    # The table metas map each storage key to the name of its protobuf class.
    class_names = storage.get_data_table_metas_by_instance(data_table=table)
    for key, serialized in table.collect(use_serialize=False):
        owner_key, _, buffer_name = key.partition(':')
        if owner_key != component_model_key:
            continue
        class_name = class_names.get(key, '')
        proto_class = get_proto_buffer_class(class_name)
        if not proto_class:
            raise Exception(
                'can not found this protobuffer class: {}'.format(class_name))
        message = proto_class()
        parse_proto_object(proto_object=message,
                           proto_object_serialized_bytes=serialized)
        buffers[buffer_name] = message
    return buffers
Ejemplo n.º 2
0
def collect_pipeline_model(party_model_id, model_version):
    """Load every protobuf buffer stored in a pipeline model table.

    Opens the table named ``model_version`` in namespace ``party_model_id``
    (without creating it), instantiates the protobuf class recorded for each
    storage key in the table metas and parses the stored bytes into it.

    Args:
        party_model_id: Namespace of the model table.
        model_version: Name of the model table.

    Returns:
        A dict keyed by the text after the last ``'.'`` of each storage key,
        holding the parsed protobuf object. It is empty when the table is
        falsy.

    Raises:
        Exception: If no protobuf class can be resolved for an entry.
    """
    table = storage.table(
        name=model_version,
        namespace=party_model_id,
        partition=get_model_table_partition_count(),
        create_if_missing=False,
        error_if_exist=False)
    buffers = {}
    if not table:
        return buffers

    class_names = storage.get_data_table_metas_by_instance(data_table=table)
    for key, serialized in table.collect(use_serialize=False):
        # Only the last dot-separated segment of the key names the buffer.
        buffer_name = key.rsplit('.', 1)[-1]
        class_name = class_names.get(key, '')
        proto_class = get_proto_buffer_class(class_name)
        if not proto_class:
            raise Exception(
                'can not found this protobuffer class: {}'.format(class_name))
        message = proto_class()
        message.ParseFromString(serialized)
        buffers[buffer_name] = message
    return buffers
Ejemplo n.º 3
0
def save_component_model(component_model_key,
                         model_buffers,
                         party_model_id,
                         model_version,
                         version_log=None):
    """Persist a component's protobuf buffers and record a model version.

    Each buffer is serialized and stored under the key
    ``'<component_model_key>:<buffer_name>'`` in the table named
    ``model_version`` within namespace ``party_model_id`` (created if
    missing). The class name of every buffer is saved as table meta, and a
    version entry is written afterwards.

    Args:
        component_model_key: Key prefix for this component's entries.
        model_buffers: Mapping of buffer name to protobuf object.
        party_model_id: Namespace of the model table.
        model_version: Name of the model table.
        version_log: Message for the version entry; when falsy, an
            "[AUTO] save model at <now>." message is generated.
    """
    table = storage.table(
        name=model_version,
        namespace=party_model_id,
        partition=get_model_table_partition_count(),
        create_if_missing=True,
        error_if_exist=False)

    class_names = {}
    for buffer_name, message in model_buffers.items():
        key = '{}:{}'.format(component_model_key, buffer_name)
        payload = message.SerializeToString()
        if not payload:
            # An empty serialization is swapped for a flagged placeholder
            # message so that non-empty bytes are stored.
            placeholder = default_empty_fill_pb2.DefaultEmptyFillMessage()
            placeholder.flag = 'set'
            payload = placeholder.SerializeToString()
        table.put(key, payload, use_serialize=False)
        class_names[key] = type(message).__name__

    storage.save_data_table_meta(class_names,
                                 data_table_namespace=party_model_id,
                                 data_table_name=model_version)

    if not version_log:
        version_log = "[AUTO] save model at %s." % datetime.datetime.now()
    version_control.save_version(name=model_version,
                                 namespace=party_model_id,
                                 version_log=version_log)
Ejemplo n.º 4
0
 def get_output_data_table(self, data_name: str = 'component'):
     """Return the output data table registered under ``data_name``.

     Looks ``data_name`` up in the output-data info table of
     ``self.table_namespace``. The stored info supplies the ``name`` and
     ``namespace`` of the actual table, which is then opened. If the table
     metas carry a truthy ``schema`` it is attached to the table as
     ``schema``.

     Returns:
         The opened data table, or ``None`` when nothing is registered
         under ``data_name``.
     """
     info_table = storage.table(name=Tracking.output_table_name('data'),
                                namespace=self.table_namespace)
     table_info = info_table.get(data_name)
     if not table_info:
         return None

     table = storage.table(name=table_info.get('name', ''),
                           namespace=table_info.get('namespace', ''))
     meta = storage.get_data_table_metas_by_instance(data_table=table)
     schema = meta.get('schema', None)
     if schema:
         table.schema = schema
     return table
Ejemplo n.º 5
0
    def get_task_run_args(job_id, role, party_id, job_parameters, job_args,
                          input_dsl):
        """Resolve a task's input DSL into concrete tables and models.

        Args:
            job_id: Job identifier passed to ``Tracking``.
            role: Role passed to ``Tracking``.
            party_id: Party identifier passed to ``Tracking``.
            job_parameters: Mapping providing ``'model_id'`` and
                ``'model_version'``; only read for model inputs.
            job_args: Job arguments; ``job_args['data'][<data_name>]`` is
                expected to hold ``'namespace'`` and ``'name'`` for inputs
                referenced as ``args.<data_name>``.
            input_dsl: Mapping of input type to its details. ``'data'`` maps
                a data type to a list of ``<component>.<data_name>`` keys;
                ``'model'`` / ``'isometric_model'`` hold
                ``<component>.<model_name>`` or
                ``pipeline.<component>.<model_name>`` keys. Other input
                types are ignored.

        Returns:
            A dict shaped as ``{'data': {component: {data_type: table}},
            'model': {component: models}, ...}``. A data table is ``None``
            when an ``args`` reference has no usable namespace/name.

        Raises:
            IndexError: If a data key contains no ``'.'``.
            Exception: If a model key has neither of the accepted forms.
        """
        task_run_args = {}
        for input_type, input_detail in input_dsl.items():
            if input_type == 'data':
                this_type_args = task_run_args.setdefault(input_type, {})
                for data_type, data_list in input_detail.items():
                    for data_key in data_list:
                        data_key_item = data_key.split('.')
                        search_component_name = data_key_item[0]
                        search_data_name = data_key_item[1]
                        if search_component_name == 'args':
                            # A data name missing from job_args['data'] is
                            # treated like an entry without namespace/name
                            # (table is None) instead of failing on a None
                            # lookup.
                            table_info = job_args.get('data', {}).get(
                                search_data_name) or {}
                            table_namespace = table_info.get('namespace', '')
                            table_name = table_info.get('name', '')
                            if table_namespace and table_name:
                                data_table = storage.table(
                                    namespace=table_namespace,
                                    name=table_name)
                            else:
                                data_table = None
                        else:
                            # Output of an upstream component.
                            data_table = Tracking(
                                job_id=job_id,
                                role=role,
                                party_id=party_id,
                                component_name=search_component_name
                            ).get_output_data_table(data_name=search_data_name)
                        args_from_component = this_type_args.setdefault(
                            search_component_name, {})
                        args_from_component[data_type] = data_table
            elif input_type in ['model', 'isometric_model']:
                this_type_args = task_run_args.setdefault(input_type, {})
                for dsl_model_key in input_detail:
                    dsl_model_key_items = dsl_model_key.split('.')
                    if len(dsl_model_key_items) == 2:
                        search_component_name, search_model_name = (
                            dsl_model_key_items)
                    elif (len(dsl_model_key_items) == 3
                          and dsl_model_key_items[0] == 'pipeline'):
                        # Drop the leading 'pipeline' segment.
                        search_component_name, search_model_name = (
                            dsl_model_key_items[1:])
                    else:
                        raise Exception(
                            'get input {} failed'.format(input_type))
                    models = Tracking(
                        job_id=job_id,
                        role=role,
                        party_id=party_id,
                        component_name=search_component_name,
                        model_id=job_parameters['model_id'],
                        model_version=job_parameters['model_version']
                    ).get_output_model(model_name=search_model_name)
                    this_type_args[search_component_name] = models
        return task_run_args