Ejemplo n.º 1
0
 def _cheif_barriar(self, is_chief=False, sync_times=300):
     """Synchronization barrier across workers, backed by the kvstore.

     Every worker writes a "ready" flag under
     ``<APPLICATION_ID>/<WORKER_RANK>``. The chief then polls the kvstore
     until at least ``REPLICA_NUM`` workers have checked in, or until
     ``sync_times`` polls (6 seconds apart) have elapsed.

     NOTE(review): the method name looks like a typo for
     ``_chief_barrier``; kept as-is so existing callers keep working.

     Args:
         is_chief: if True, block polling until all workers are ready;
             non-chief workers only register their flag and return.
         sync_times: maximum number of polls before giving up silently.
     """
     # Environment variables are strings; convert to int so the
     # comparison with len(sync_list) below doesn't raise
     # TypeError ('<' between int and str) when REPLICA_NUM is set.
     worker_replicas = int(os.environ.get('REPLICA_NUM', 0))
     kvstore_type = os.environ.get('KVSTORE_TYPE', 'etcd')
     db_database, db_addr, db_username, db_password, _ = \
         get_kvstore_config(kvstore_type)
     kvstore_client = DBClient(db_database,
                               db_addr,
                               db_username,
                               db_password,
                               SYNC_PATH)
     sync_path = '%s/%s' % (os.environ['APPLICATION_ID'],
                            os.environ['WORKER_RANK'])
     logging.info('Creating a sync flag at %s', sync_path)
     kvstore_client.set_data(sync_path, "1")
     if is_chief:
         for _ in range(sync_times):
             sync_list = kvstore_client.get_prefix_kvs(
                 os.environ['APPLICATION_ID'])
             logging.info('Sync file pattern is: %s', sync_list)
             if len(sync_list) >= worker_replicas:
                 break
             logging.info('Count of ready workers is %d',
                          len(sync_list))
             time.sleep(6)
Ejemplo n.º 2
0
    def setUp(self):
        """Build two local data sources (leader/follower) from MNIST.

        Creates one kvstore-backed data source per role, splits the MNIST
        feature columns vertically between leader and follower, writes the
        data blocks, and marks every partition's join state as Joined so
        the trainer under test can consume them directly.
        """
        self.sche = _TaskScheduler(30)
        self.kv_store = [None, None]
        self.app_id = "test_trainer_v1"
        # The configured database name is unused here: each DBClient is
        # keyed by its data source's own name instead.
        _, db_addr, db_username, db_password, db_base_dir = \
                get_kvstore_config("etcd")
        data_source = [
            self._gen_ds_meta(common_pb.FLRole.Leader),
            self._gen_ds_meta(common_pb.FLRole.Follower)
        ]
        for role in range(2):
            self.kv_store[role] = mysql_client.DBClient(
                data_source[role].data_source_meta.name, db_addr, db_username,
                db_password, db_base_dir, True)
        self.data_source = data_source
        # Load MNIST (from a local copy in debug mode to avoid a network
        # download) and keep only 200 samples so the test stays fast.
        if debug_mode:
            (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
        else:
            (x, y), _ = tf.keras.datasets.mnist.load_data()
        x = x[:200, ]

        x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
        y = y.astype(np.int64)

        # Vertical feature split: leader gets the left half of the
        # columns, follower the right half.
        xl = x[:, :x.shape[1] // 2]
        xf = x[:, x.shape[1] // 2:]

        self._create_local_data(xl, xf, y)

        x = [xl, xf]
        for role in range(2):
            common.commit_data_source(self.kv_store[role], data_source[role])
            if gfile.Exists(data_source[role].output_base_dir):
                gfile.DeleteRecursively(data_source[role].output_base_dir)
            manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
                self.kv_store[role], data_source[role])
            partition_num = data_source[role].data_source_meta.partition_num
            for i in range(partition_num):
                self._create_data_block(data_source[role], i, x[role], y)
                # Force every partition to the Joined state so the
                # trainer does not wait on an example-join step.
                manifest_manager._finish_partition(
                    'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                    dj_pb.JoinExampleState.Joined, -1, i)
Ejemplo n.º 3
0
    def __init__(self,
                 base_path,
                 name,
                 role,
                 partition_num=1,
                 start_time=0,
                 end_time=100000):
        """Set up a test data source and mark all partitions as joined.

        Builds a DataSource proto under ``base_path``, wipes any previous
        output directory, commits the source to the kvstore, finishes the
        example-join state for every partition, and keeps one
        DataBlockManager per partition in ``self._dbms``.

        Raises:
            ValueError: if ``role`` is neither 'leader' nor 'follower'.
        """
        if role == 'follower':
            role = 1
        elif role == 'leader':
            role = 0
        else:
            raise ValueError("Unknown role %s" % role)

        data_source = common_pb.DataSource()
        meta = data_source.data_source_meta
        meta.name = name
        meta.partition_num = partition_num
        meta.start_time = start_time
        meta.end_time = end_time
        data_source.output_base_dir = "{}/{}_{}/data_source/".format(
            base_path, meta.name, role)
        data_source.role = role
        # Always start from a clean output directory.
        if gfile.Exists(data_source.output_base_dir):
            gfile.DeleteRecursively(data_source.output_base_dir)

        self._data_source = data_source

        db_database, db_addr, db_username, db_password, db_base_dir = \
            get_kvstore_config("etcd")
        self._kv_store = mysql_client.DBClient(db_database, db_addr,
                                               db_username, db_password,
                                               db_base_dir, True)

        common.commit_data_source(self._kv_store, self._data_source)
        self._dbms = []
        for partition_id in range(partition_num):
            manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
                self._kv_store, self._data_source)
            manifest_manager._finish_partition('join_example_rep',
                                               dj_pb.JoinExampleState.UnJoined,
                                               dj_pb.JoinExampleState.Joined,
                                               -1, partition_id)
            self._dbms.append(
                data_block_manager.DataBlockManager(self._data_source,
                                                    partition_id))
Ejemplo n.º 4
0
import collections
import os
import traceback

import tensorflow.compat.v1 as tf

from fedlearner.common import trainer_master_service_pb2 as tm_pb
from fedlearner.common import trainer_master_service_pb2_grpc as tm_grpc
from fedlearner.proxy.channel import make_insecure_channel, ChannelType
from fedlearner.common import common_pb2 as common_pb
from fedlearner.data_join.data_block_visitor import DataBlockVisitor
from fedlearner.data_join.common import get_kvstore_config

# Lightweight record describing one data block: its identifier and the
# path to its backing file.
DataBlockInfo = collections.namedtuple('DataBlockInfo',
                                       ['block_id', 'data_path'])
# Resolve kvstore connection settings once at import time; defaults to
# etcd when KVSTORE_TYPE is not set in the environment.
kvstore_type = os.environ.get('KVSTORE_TYPE', 'etcd')
db_database, db_addr, db_username, db_password, db_base_dir = \
    get_kvstore_config(kvstore_type)


class LocalTrainerMasterClient(object):
    """Non-thread safe"""
    def __init__(self,
                 role,
                 path,
                 files=None,
                 ext='.tfrecord',
                 start_time=None,
                 end_time=None,
                 from_data_source=False,
                 skip_datablock_checkpoint=False,
                 epoch_num=1):
        self._role = role
                args.sort_run_merger_read_batch_size,
            batch_processor_options=dj_pb.BatchProcessorOptions(
                batch_size=args.process_batch_size,
                max_flying_item=-1
            ),
            input_raw_data=dj_pb.RawDataOptions(
                raw_data_iter=args.raw_data_iter,
                compressed_type=args.compressed_type,
                read_ahead_size=args.read_ahead_size,
                read_batch_size=args.read_batch_size
            ),
            writer_options=dj_pb.WriterOptions(
                output_writer=args.output_builder,
                compressed_type=args.builder_compressed_type,
            )
        )
    if args.psi_role.upper() == 'LEADER':
        preprocessor_options.role = common_pb.FLRole.Leader
    else:
        assert args.psi_role.upper() == 'FOLLOWER'
        preprocessor_options.role = common_pb.FLRole.Follower
    db_database, db_addr, db_username, db_password, db_base_dir = \
        get_kvstore_config(args.kvstore_type)
    preprocessor = RsaPsiPreProcessor(preprocessor_options, db_database,
                                      db_base_dir, db_addr, db_username,
                                      db_password)
    preprocessor.start_process()
    logging.info("PreProcessor launched for %s of RSA PSI", args.psi_role)
    preprocessor.wait_for_finished()
    logging.info("PreProcessor finished for %s of RSA PSI", args.psi_role)
Ejemplo n.º 6
0
import etcd3
from fedlearner.common.mysql_client import DBClient
from fedlearner.data_join.common import get_kvstore_config

# One-shot migration: copy every key/value pair from etcd into MySQL.
database, addr, username, password, base_dir = \
    get_kvstore_config('mysql')
MySQL_client = DBClient(database, addr, username, password, base_dir)
database, addr, username, password, base_dir = \
    get_kvstore_config('etcd')
(host, port) = addr.split(':')
options = [('grpc.max_send_message_length', 2**31-1),
    ('grpc.max_receive_message_length', 2**31-1)]
clnt = etcd3.client(host=host, port=port,
    grpc_options=options)
# etcd3.get_prefix yields (value, KVMetadata) pairs; keys and values are
# bytes and must be decoded. The original called the nonexistent
# bytes.decoder() (AttributeError) and, on the str branch, passed the
# whole KVMetadata object instead of the key string.
for (data, key) in clnt.get_prefix('/', sort_order='ascend'):
    if not isinstance(key.key, str):
        key = key.key.decode()
    else:
        key = key.key
    if not isinstance(data, str):
        data = data.decode()
    MySQL_client.set_data(key, data)