Пример #1
0
def save_component_model(component_model_key,
                         model_buffers,
                         party_model_id,
                         model_version,
                         version_log=None):
    """Store every protobuf buffer of a component in the model table.

    The table named ``model_version`` in namespace ``party_model_id`` is
    opened (created when missing). Each buffer is written under the key
    ``"<component_model_key>:<buffer_name>"`` without extra serialization.
    When a buffer serializes to empty bytes, a serialized
    ``DefaultEmptyFillMessage`` with ``flag='set'`` is stored in its place.
    The key -> protobuf class name mapping is saved as table meta, and a
    version entry is recorded afterwards.

    :param component_model_key: prefix shared by this component's keys.
    :param model_buffers: mapping of buffer name -> protobuf message.
    :param party_model_id: namespace of the model table / version entry.
    :param model_version: name of the model table / version entry.
    :param version_log: version message; when falsy an automatic
        ``"[AUTO] save model at <now>."`` message is used.
    """
    model_table = storage.table(name=model_version,
                                namespace=party_model_id,
                                partition=get_model_table_partition_count(),
                                create_if_missing=True,
                                error_if_exist=False)
    class_name_by_key = {}
    for name, message in model_buffers.items():
        key = '{}:{}'.format(component_model_key, name)
        payload = message.SerializeToString()
        if not payload:
            # Store the fill placeholder instead of an empty value.
            placeholder = default_empty_fill_pb2.DefaultEmptyFillMessage()
            placeholder.flag = 'set'
            payload = placeholder.SerializeToString()
        model_table.put(key, payload, use_serialize=False)
        class_name_by_key[key] = type(message).__name__
    storage.save_data_table_meta(class_name_by_key,
                                 data_table_namespace=party_model_id,
                                 data_table_name=model_version)
    if not version_log:
        version_log = "[AUTO] save model at %s." % datetime.datetime.now()
    version_control.save_version(name=model_version,
                                 namespace=party_model_id,
                                 version_log=version_log)
Пример #2
0
 def save_component_model(self, component_name, component_module_name,
                          model_alias, model_buffers):
     """Write every protobuf buffer of one component alias to its own file.

     Files are placed under
     ``<self.variables_data_path>/<component_name>/<model_alias>/`` (created
     when missing), one file per entry of ``model_buffers``. When a buffer
     serializes to empty bytes, a serialized ``DefaultEmptyFillMessage`` with
     ``flag='set'`` is written instead. The resulting buffer name -> protobuf
     class name index is passed to ``self.update_component_meta``.

     :param component_name: component whose buffers are saved.
     :param component_module_name: forwarded to ``update_component_meta``.
     :param model_alias: sub-directory / alias of this set of buffers.
     :param model_buffers: mapping of buffer name -> protobuf message.
     """
     alias_dir = os.path.join(self.variables_data_path, component_name,
                              model_alias)
     os.makedirs(alias_dir, exist_ok=True)
     proto_index = {}
     for buffer_name, message in model_buffers.items():
         target_file = os.path.join(alias_dir, buffer_name)
         payload = message.SerializeToString()
         if not payload:
             # Write the fill placeholder rather than the empty serialization.
             placeholder = default_empty_fill_pb2.DefaultEmptyFillMessage()
             placeholder.flag = 'set'
             payload = placeholder.SerializeToString()
         with open(target_file, "wb") as out:
             out.write(payload)
         # buffer name -> protobuf class name, handed to the component meta
         proto_index[buffer_name] = type(message).__name__
         stat_logger.info("Save {} {} {} buffer".format(
             component_name, model_alias, buffer_name))
     self.update_component_meta(component_name=component_name,
                                component_module_name=component_module_name,
                                model_alias=model_alias,
                                model_proto_index=proto_index)
     stat_logger.info("Save {} {} successfully".format(
         component_name, model_alias))
Пример #3
0
 def save_pipeline(self, pipelined_buffer_object):
     """Serialize the pipeline protobuf into ``<self.model_path>/pipeline.pb``.

     When the message serializes to empty bytes, a serialized
     ``DefaultEmptyFillMessage`` with ``flag='set'`` is written instead.

     :param pipelined_buffer_object: protobuf message describing the pipeline.
     """
     payload = pipelined_buffer_object.SerializeToString()
     if not payload:
         placeholder = default_empty_fill_pb2.DefaultEmptyFillMessage()
         placeholder.flag = 'set'
         payload = placeholder.SerializeToString()
     pipeline_file = os.path.join(self.model_path, "pipeline.pb")
     with open(pipeline_file, "wb") as out:
         out.write(payload)
Пример #4
0
def parse_proto_object(proto_object, proto_object_serialized_bytes):
    """Parse serialized bytes into ``proto_object`` in place.

    When the bytes cannot be parsed as ``proto_object``, they are retried as
    a ``DefaultEmptyFillMessage`` (the placeholder stored in place of buffers
    that serialize to empty bytes). If that succeeds, ``proto_object`` is
    reset to its default values.

    :param proto_object: protobuf message instance to populate (mutated).
    :param proto_object_serialized_bytes: serialized bytes to parse.
    :raises Exception: the original parse error when neither the real
        message nor the placeholder can be parsed; the placeholder error is
        logged before re-raising.
    """
    try:
        proto_object.ParseFromString(proto_object_serialized_bytes)
    except Exception as e1:
        try:
            fill_message = default_empty_fill_pb2.DefaultEmptyFillMessage()
            fill_message.ParseFromString(proto_object_serialized_bytes)
            # Parsing empty bytes leaves the message with only default values.
            proto_object.ParseFromString(bytes())
        except Exception as e2:
            stat_logger.exception(e2)
            raise e1
        else:
            stat_logger.info(
                'parse {} proto object with default values'.format(
                    type(proto_object).__name__))
    else:
        # Logging is kept outside the try: a logging failure must not be
        # mistaken for a parse failure (which would run the fallback and
        # could wipe a correctly parsed message back to defaults).
        stat_logger.info('parse {} proto object normal'.format(
            type(proto_object).__name__))
Пример #5
0
 def parse_proto_object(self, buffer_name, buffer_object_serialized_string):
     """Instantiate the protobuf class for ``buffer_name`` and parse bytes into it.

     When the bytes cannot be parsed as that class, they are retried as a
     ``DefaultEmptyFillMessage`` (the placeholder stored in place of buffers
     that serialize to empty bytes). If that succeeds, the returned object
     holds only default values.

     :param buffer_name: name resolved via ``self.get_proto_buffer_class``.
     :param buffer_object_serialized_string: serialized bytes to parse.
     :returns: the populated protobuf message.
     :raises Exception: whatever resolving/instantiating the class raises
         (after logging it); or the original parse error when neither the
         real message nor the placeholder can be parsed.
     """
     try:
         buffer_object = self.get_proto_buffer_class(buffer_name)()
     except Exception:
         # exception() attaches the active traceback itself. The message needs
         # a real %s placeholder for its argument; passing the exception as a
         # stray argument makes a stdlib logger fail to format the record.
         # (Assumes stat_logger is a stdlib logging.Logger — confirm.)
         stat_logger.exception("Can not restore proto buffer object %s",
                               buffer_name)
         raise
     try:
         buffer_object.ParseFromString(buffer_object_serialized_string)
     except Exception as e1:
         try:
             fill_message = default_empty_fill_pb2.DefaultEmptyFillMessage()
             fill_message.ParseFromString(buffer_object_serialized_string)
             # Parsing empty bytes leaves the message with only default values.
             buffer_object.ParseFromString(bytes())
         except Exception as e2:
             stat_logger.exception(e2)
             raise e1
         stat_logger.info('parse {} proto object with default values'.format(type(buffer_object).__name__))
         return buffer_object
     # Logged outside the try so a logging failure is never treated as a
     # parse failure (which would reset a correctly parsed message).
     stat_logger.info('parse {} proto object normal'.format(type(buffer_object).__name__))
     return buffer_object