Esempio n. 1
0
def save_component_model(component_model_key,
                         model_buffers,
                         party_model_id,
                         model_version,
                         version_log=None):
    """Persist every protobuf buffer of one component model and record a version.

    Each buffer is written to the table ``(namespace=party_model_id,
    name=model_version)`` under the key ``"<component_model_key>:<buffer_name>"``.
    When a buffer serializes to an empty value, a serialized
    ``DefaultEmptyFillMessage`` with ``flag = 'set'`` is written in its place.
    The protobuf class name of every stored buffer is saved as the table's
    metadata, and a version entry is recorded afterwards.

    :param component_model_key: prefix used for every storage key of this component
    :param model_buffers: mapping of buffer name -> protobuf message
    :param party_model_id: namespace of the model table
    :param model_version: name of the model table, also used as the version name
    :param version_log: version message; when falsy, an
        "[AUTO] save model at <now>." message is generated
    """
    model_table = storage.table(
        name=model_version,
        namespace=party_model_id,
        partition=get_model_table_partition_count(),
        create_if_missing=True,
        error_if_exist=False)

    class_names = {}
    for name, message in model_buffers.items():
        key = '{}:{}'.format(component_model_key, name)
        payload = message.SerializeToString()
        if not payload:
            # Write a placeholder message instead of an empty value.
            placeholder = default_empty_fill_pb2.DefaultEmptyFillMessage()
            placeholder.flag = 'set'
            payload = placeholder.SerializeToString()
        model_table.put(key, payload, use_serialize=False)
        class_names[key] = type(message).__name__

    storage.save_data_table_meta(class_names,
                                 data_table_namespace=party_model_id,
                                 data_table_name=model_version)

    if not version_log:
        version_log = "[AUTO] save model at %s." % datetime.datetime.now()
    version_control.save_version(name=model_version,
                                 namespace=party_model_id,
                                 version_log=version_log)
Esempio n. 2
0
def save_data(data_inst, namespace, version):
    """Persist ``data_inst`` as table ``(namespace, version)`` and record a version.

    Before saving, ``RedisAdaptor().setex(namespace, version)`` is called.
    The matching ``delete(namespace)`` runs only after the table, its metadata
    and its version record have all been written.

    :param data_inst: table-like object exposing ``save_as`` and a dict-like ``schema``
    :param namespace: destination namespace
    :param version: destination table name
    :return: dict with the ``table_name`` and ``namespace`` of the saved table
    """
    redis_adapter = RedisAdaptor()
    redis_adapter.setex(namespace, version)

    persistent_table = data_inst.save_as(namespace=namespace, name=version)
    LOGGER.info("save data to namespace={}, name={}".format(
        persistent_table._namespace, persistent_table._name))

    session.save_data_table_meta(
        {
            'schema': data_inst.schema,
            'header': data_inst.schema.get('header', [])
        },
        data_table_namespace=persistent_table._namespace,
        data_table_name=persistent_table._name)

    version_log = "[AUTO] save data at %s." % datetime.datetime.now()
    version_control.save_version(name=persistent_table._name,
                                 namespace=persistent_table._namespace,
                                 version_log=version_log)

    # NOTE(review): if any step above raises, this delete is skipped and the
    # Redis entry set by setex() stays behind (presumably until it expires).
    # Confirm whether a try/finally cleanup is wanted here.
    redis_adapter.delete(namespace)

    # Fixed typo in the log text ("namepsace" -> "namespace").
    LOGGER.info('save table done, namespace={}, version={}.'.format(
        persistent_table._namespace, persistent_table._name))
    return {
        'table_name': persistent_table._name,
        'namespace': persistent_table._namespace
    }
Esempio n. 3
0
def save_model(buffer_type, proto_buffer, name, namespace, version_log=None):
    """Store one serialized protobuf buffer in a model table and record a version.

    :param buffer_type: key under which the serialized buffer is written
    :param proto_buffer: protobuf message, stored via ``SerializeToString()``
    :param name: table name, also used as the version name
    :param namespace: table namespace
    :param version_log: version message; when falsy, an
        "[AUTO] save model at <now>." message is generated
    """
    model_table = eggroll.table(
        name=name,
        namespace=namespace,
        partition=get_model_table_partition_count(),
        create_if_missing=True,
        error_if_exist=False,
    )
    # TODO: model slicing is not handled; the whole buffer goes under one key.
    serialized = proto_buffer.SerializeToString()
    model_table.put(buffer_type, serialized, use_serialize=False)

    if not version_log:
        version_log = "[AUTO] save model at %s." % datetime.datetime.now()
    version_control.save_version(name=name, namespace=namespace, version_log=version_log)
Esempio n. 4
0
    def save_data_table(self,
                        dst_table_name,
                        dst_table_namespace,
                        head=True,
                        in_version=False):
        """Upload ``self.parameters["file"]`` into a data table, chunk by chunk.

        Each data line is split on ``,``, with tabs treated as commas. The first
        field becomes the key; the remaining fields are passed to
        ``self.list_to_str`` to build the value. Each chunk returned by
        ``readlines(self.MAX_BYTES)`` is written with ``session.save_data``, and
        job progress is reported after every chunk.

        Once the file is exhausted, the method records the data view and an
        upload ``data_access`` metric. When ``in_version`` is set, it also
        saves a version entry.

        :param dst_table_name: destination table name
        :param dst_table_namespace: destination table namespace
        :param head: when True, the first line is treated as the header; it is
            saved via ``save_data_header`` and stored in ``self.table_info["cols"]``
        :param in_version: when True, record a version for the destination table
        :return: row count reported by the destination table after the upload
        :raises ValueError: if the file contains no data lines
        """
        input_file = self.parameters["file"]
        count = self.get_count(input_file)
        data_table = None
        with open(input_file, 'r') as fin:
            lines_count = 0
            if head is True:
                data_head = fin.readline()
                count -= 1  # the header line is not a data row
                self.save_data_header(data_head, dst_table_name,
                                      dst_table_namespace)
                self.table_info["cols"] = data_head
            while True:
                # readlines(hint) stops once about MAX_BYTES have been read, so
                # each iteration handles one bounded chunk of the file.
                lines = fin.readlines(self.MAX_BYTES)
                if not lines:
                    break
                data = []
                for line in lines:
                    values = line.replace("\n", "").replace("\t",
                                                            ",").split(",")
                    data.append((values[0], self.list_to_str(values[1:])))
                lines_count += len(data)
                f_progress = lines_count / count * 100 // 1
                job_info = {'f_progress': f_progress}
                self.update_job_status(
                    self.parameters["local"]['role'],
                    self.parameters["local"]['party_id'], job_info)
                data_table = session.save_data(
                    data,
                    name=dst_table_name,
                    namespace=dst_table_namespace,
                    partition=self.parameters["partition"])
                # This previously read the undefined name `data_table_count`,
                # raising NameError on the first chunk.
                # NOTE(review): confirm the table row count is the value
                # intended for "v_len".
                self.table_info["v_len"] = data_table.count()

        if data_table is None:
            # No data lines were read, so no table was written. Previously this
            # surfaced as an UnboundLocalError on `data_table`.
            raise ValueError(
                "no data lines found in upload file {}".format(input_file))

        table_count = data_table.count()
        self.tracker.save_data_view(
            role=self.parameters["local"]['role'],
            party_id=self.parameters["local"]['party_id'],
            data_info={
                'f_table_name': dst_table_name,
                'f_table_namespace': dst_table_namespace,
                'f_partition': self.parameters["partition"],
                'f_table_count_actual': table_count,
                'f_table_count_upload': count
            })
        self.callback_metric(
            metric_name='data_access',
            metric_namespace='upload',
            metric_data=[Metric("count", table_count)])
        if in_version:
            version_log = "[AUTO] save data at %s." % datetime.datetime.now()
            version_control.save_version(
                name=dst_table_name,
                namespace=dst_table_namespace,
                version_log=version_log)
        return table_count
Esempio n. 5
0
def save_data(kv_data: Iterable, name, namespace, partition=1, persistent: bool = True, create_if_missing=True, error_if_exist=False,
              in_version: bool = False, version_log=None):
    """
    Write key/value pairs into an eggroll data table, optionally recording a version.

    :param kv_data: iterable of key/value pairs, written with ``put_all``
    :param name: table name of the data table (also the version name)
    :param namespace: table namespace of the data table
    :param partition: number of partitions
    :param persistent: forwarded to ``eggroll.table``
    :param create_if_missing: forwarded to ``eggroll.table``
    :param error_if_exist: forwarded to ``eggroll.table``
    :param in_version: when True, also record a version for the table
    :param version_log: version message used when ``in_version`` is True; when
        falsy, an "[AUTO] save data at <now>." message is generated
    :return:
        data table instance
    """
    table = eggroll.table(
        name=name,
        namespace=namespace,
        partition=partition,
        persistent=persistent,
        create_if_missing=create_if_missing,
        error_if_exist=error_if_exist,
    )
    table.put_all(kv_data)
    if not in_version:
        return table

    log_message = version_log or "[AUTO] save data at %s." % datetime.datetime.now()
    version_control.save_version(name=name, namespace=namespace, version_log=log_message)
    return table
Esempio n. 6
0
    def save_data(self):
        """Save ``self.data_processed`` to the configured output table and record a version.

        The destination comes from ``self.model_param.save_out_table_namespace``
        and ``self.model_param.save_out_table_name``. The data's schema and its
        ``header`` entry (default ``[]``) are stored as the table's metadata.

        :return: None
        """
        out_namespace = self.model_param.save_out_table_namespace
        out_name = self.model_param.save_out_table_name
        saved = self.data_processed.save_as(namespace=out_namespace, name=out_name)
        table_namespace, table_name = saved._namespace, saved._name
        LOGGER.info("save data to namespace={}, name={}".format(
            table_namespace, table_name))

        schema = self.data_processed.schema
        table_meta = {'schema': schema, 'header': schema.get('header', [])}
        session.save_data_table_meta(table_meta,
                                     data_table_namespace=table_namespace,
                                     data_table_name=table_name)

        version_control.save_version(
            name=table_name,
            namespace=table_namespace,
            version_log="[AUTO] save data at %s." % datetime.datetime.now())
Esempio n. 7
0
def save_table_version():
    """Record a table version from the JSON body of the current request.

    The ``name``, ``namespace`` and ``version_log`` fields are read from
    ``request.json``; missing keys are passed on as None. The handler always
    replies with ``retcode=0`` / ``retmsg='success'`` once ``save_version``
    returns.
    """
    payload = request.json
    save_version(name=payload.get('name'),
                 namespace=payload.get('namespace'),
                 version_log=payload.get('version_log'))
    return get_json_result(retcode=0, retmsg='success')