Example #1
    def tearDown(self):
        if gfile.Exists(self.data_source_f.output_base_dir):
            gfile.DeleteRecursively(self.data_source_f.output_base_dir)
        if gfile.Exists(self.data_source_l.output_base_dir):
            gfile.DeleteRecursively(self.data_source_l.output_base_dir)
        if gfile.Exists(self.raw_data_dir_l):
            gfile.DeleteRecursively(self.raw_data_dir_l)
        self.etcd.delete_prefix(
            common.data_source_etcd_base_dir(
                self.data_source_l.data_source_meta.name))
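
Nearly every example on this page shares one idiom: guard gfile.DeleteRecursively with gfile.Exists, because DeleteRecursively raises if the path is missing. A minimal helper capturing the pattern, sketched under the assumption that gfile is TensorFlow's tensorflow.python.platform.gfile (the names remove_if_exists and recreate_dir are hypothetical, not fedlearner API):

from tensorflow.python.platform import gfile

def remove_if_exists(path):
    # DeleteRecursively fails on a missing path, so check first;
    # this keeps repeated setUp/tearDown calls idempotent.
    if gfile.Exists(path):
        gfile.DeleteRecursively(path)

def recreate_dir(path):
    # Start a test from a guaranteed-empty directory.
    remove_if_exists(path)
    gfile.MakeDirs(path)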
Example #2
    def setUp(self) -> None:
        logging.getLogger().setLevel(logging.DEBUG)
        self._data_portal_name = 'test_data_portal_job_manager'

        self._kvstore = DBClient('etcd', True)
        self._portal_input_base_dir = './portal_input_dir'
        self._portal_output_base_dir = './portal_output_dir'
        self._raw_data_publish_dir = 'raw_data_publish_dir'
        if gfile.Exists(self._portal_input_base_dir):
            gfile.DeleteRecursively(self._portal_input_base_dir)
        gfile.MakeDirs(self._portal_input_base_dir)

        self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
        self._data_fnames_without_success = \
            ['1002/{}.data'.format(i) for i in range(100)]
        self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
        self._unused_fnames = ['{}.xx'.format(100)]
        self._all_fnames = self._data_fnames + \
                           self._data_fnames_without_success + \
                           self._csv_fnames + self._unused_fnames

        all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] +\
                                  self._all_fnames
        for fname in all_fnames_with_success:
            fpath = os.path.join(self._portal_input_base_dir, fname)
            gfile.MakeDirs(os.path.dirname(fpath))
            with gfile.Open(fpath, "w") as f:
                f.write('xxx')
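
The fixture above writes a _SUCCESS marker into 1001/ and 1003/ but deliberately not into 1002/, presumably so the job manager under test treats 1002/ as unfinished. A quick sanity check for that layout, sketched with TensorFlow's gfile (has_success_tag is a hypothetical helper, not part of the test):

import os
from tensorflow.python.platform import gfile

def has_success_tag(base_dir, sub_dir):
    # True for '1001' and '1003' after setUp runs, False for '1002'.
    return gfile.Exists(os.path.join(base_dir, sub_dir, '_SUCCESS'))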
Example #3
    def test_raw_data_manifest_manager_with_nfs(self):
        root_dir = "test_fedlearner"
        os.environ["STORAGE_ROOT_PATH"] = root_dir
        cli = db_client.DBClient('nfs', True)
        self._raw_data_manifest_manager(cli)
        if gfile.Exists(root_dir):
            gfile.DeleteRecursively(root_dir)
Example #4
    def _create_local_data(self, xl, xf, y):
        N = 10
        chunk_size = xl.shape[0]//N
        leader_worker_path = os.path.join(output_path, "data/leader")
        follower_worker_path = os.path.join(output_path, "data/follower")

        data_path = os.path.join(output_path, "data")
        if gfile.Exists(data_path):
            gfile.DeleteRecursively(data_path)
        os.makedirs(leader_worker_path)
        os.makedirs(follower_worker_path)

        for i in range(N):
            filename_l = os.path.join(leader_worker_path, '%02d.tfrecord'%i)
            filename_f = os.path.join(follower_worker_path, '%02d.tfrecord'%i)
            fl = tf.io.TFRecordWriter(filename_l)
            ff = tf.io.TFRecordWriter(filename_f)
            for j in range(chunk_size):
                idx = i*chunk_size + j
                features_l = {}
                features_l['example_id'] = Feature(
                    bytes_list=BytesList(value=[str(idx).encode('utf-8')]))
                features_l['y'] = Feature(int64_list=Int64List(value=[y[idx]]))
                features_l['x'] = Feature(float_list=FloatList(value=list(xl[idx])))
                fl.write(
                    Example(features=Features(feature=features_l)).SerializeToString())
                features_f = {}
                features_f['example_id'] = Feature(
                    bytes_list=BytesList(value=[str(idx).encode('utf-8')]))
                features_f['x'] = Feature(float_list=FloatList(value=list(xf[idx])))
                ff.write(
                    Example(features=Features(feature=features_f)).SerializeToString())
            fl.close()
            ff.close()
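
The writer above stores, per record, a string example_id, an int64 label y (leader side only), and a float feature vector x. A sketch of reading one leader shard back for verification, assuming TensorFlow 2.x eager execution and that dim matches xl.shape[1]:

import tensorflow as tf

def read_leader_shard(filename, dim):
    # Feature spec mirrors what _create_local_data wrote for the leader.
    spec = {
        'example_id': tf.io.FixedLenFeature([], tf.string),
        'y': tf.io.FixedLenFeature([], tf.int64),
        'x': tf.io.FixedLenFeature([dim], tf.float32),
    }
    for record in tf.data.TFRecordDataset(filename):
        yield tf.io.parse_single_example(record, spec)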
Example #5
    def test_potral_hourly_input_reducer_mapper(self):
        self._prepare_test()
        reducer = PotralHourlyInputReducer(self._portal_manifest,
                                           self._portal_options,
                                           self._date_time)
        mapper = PotralHourlyOutputMapper(self._portal_manifest,
                                          self._portal_options,
                                          self._date_time)
        expected_example_idx = 0
        for tf_item in reducer.make_reducer():
            example_id = '{}'.format(expected_example_idx).encode()
            mapper.map_data(tf_item)
            self.assertEqual(tf_item.example_id, example_id)
            expected_example_idx += 1
            if expected_example_idx % 7 == 0:
                expected_example_idx += 1
        mapper.finish_map()
        for partition_id in range(self._portal_manifest.output_partition_num):
            fpath = common.encode_portal_hourly_fpath(
                self._portal_manifest.output_data_base_dir, self._date_time,
                partition_id)
            freader = PotralHourlyInputReducer.InputFileReader(
                partition_id, fpath, self._portal_options)
            for example_idx in range(self._total_item_num):
                example_id = '{}'.format(example_idx).encode()
                if example_idx != 0 and (example_idx % 7) == 0:
                    continue
                if partition_id != CityHash32(example_id) % \
                        self._portal_manifest.output_partition_num:
                    continue
                for item in freader:
                    self.assertEqual(example_id,
                                     item.tf_example_item.example_id)
                    break
                self.assertFalse(freader.finished)
            with self.assertRaises(StopIteration):
                next(freader)
            self.assertTrue(freader.finished)

        if gfile.Exists(self._portal_manifest.input_data_base_dir):
            gfile.DeleteRecursively(self._portal_manifest.input_data_base_dir)
        if gfile.Exists(self._portal_manifest.output_data_base_dir):
            gfile.DeleteRecursively(self._portal_manifest.output_data_base_dir)
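
The assertions above rely on one routing rule: each example lands in exactly one output partition, chosen by hashing its id. In isolation the rule looks like this (a sketch assuming the same CityHash32 the test uses, e.g. from the cityhash package):

from cityhash import CityHash32

def partition_for(example_id, num_partitions):
    # Mirrors the membership check in the test: partition_id must equal
    # CityHash32(example_id) % output_partition_num.
    return CityHash32(example_id) % num_partitions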
Example #6
def make_ckpt_dir(role, remote="local", rank=None):
    if rank is None:
        rank = "N"
    ckpt_path = "{}/{}_ckpt_{}_{}".format(output_path, remote, role, rank)
    exp_path = "{}/saved_model".format(ckpt_path)
    if gfile.Exists(ckpt_path):
        gfile.DeleteRecursively(ckpt_path)
    return ckpt_path, exp_path
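
A possible call site for the helper above (the role and rank values are illustrative):

# Produces "<output_path>/local_ckpt_leader_0" plus its "saved_model" subdir,
# wiping any checkpoint left over from a previous run.
ckpt_path, exp_path = make_ckpt_dir("leader", remote="local", rank=0)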
Example #7
    def _generate_input_data(self):
        self._total_item_num = 1 << 16
        self.assertEqual(
            self._total_item_num % self._portal_manifest.input_partition_num,
            0)
        if gfile.Exists(self._portal_manifest.input_data_base_dir):
            gfile.DeleteRecursively(self._portal_manifest.input_data_base_dir)
        if gfile.Exists(self._portal_manifest.output_data_base_dir):
            gfile.DeleteRecursively(self._portal_manifest.output_data_base_dir)
        hourly_dir = common.encode_portal_hourly_dir(
            self._portal_manifest.input_data_base_dir, self._date_time)
        gfile.MakeDirs(hourly_dir)
        for partition_id in range(self._portal_manifest.input_partition_num):
            self._generate_one_part(partition_id)
        succ_tag_fpath = common.encode_portal_hourly_finish_tag(
            self._portal_manifest.input_data_base_dir, self._date_time)
        with gfile.GFile(succ_tag_fpath, 'w') as fh:
            fh.write('')
Example #8
    def setUp(self):
        self.data_source = common_pb.DataSource()
        self.data_source.data_source_meta.name = 'fclh_test'
        self.data_source.data_source_meta.partition_num = 1
        self.raw_data_dir = "./raw_data"
        self.kvstore = db_client.DBClient('etcd', True)
        self.kvstore.delete_prefix(
            common.data_source_kvstore_base_dir(
                self.data_source.data_source_meta.name))
        self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
        partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
        if gfile.Exists(partition_dir):
            gfile.DeleteRecursively(partition_dir)
        gfile.MakeDirs(partition_dir)
        self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.kvstore, self.data_source)
Example #9
    def _start_new(self):
        if self._task is not None and self._task.is_alive():
            logging.info("%s is alive, no need to start new", self.name)
            return
        self._task = Process(target=self._target, name=self.name,
                             args=self._args, kwargs=self._kwargs,
                             daemon=self._daemon)
        if isinstance(self._args[0], Args):
            logging.info("delete %s", self._args[0].export_path)
            if gfile.Exists(self._args[0].export_path):
                logging.info("deleting")
                gfile.DeleteRecursively(self._args[0].export_path)
        self._task.start()
        logging.info("Task starts %s", self.name)
        time.sleep(10)
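
The fixed time.sleep(10) simply gives the child process time to come up. A hedged alternative is to poll is_alive with a deadline instead (a sketch, not fedlearner code):

import time

def wait_until_alive(task, timeout=10.0, interval=0.1):
    # Return as soon as the process reports alive, or False after the timeout.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if task.is_alive():
            return True
        time.sleep(interval)
    return False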
Example #10
    def setUp(self):
        data_source_f = common_pb.DataSource()
        data_source_f.data_source_meta.name = "milestone"
        data_source_f.data_source_meta.partition_num = 1
        data_source_f.output_base_dir = "./output-f"
        self.data_source_f = data_source_f
        if gfile.Exists(self.data_source_f.output_base_dir):
            gfile.DeleteRecursively(self.data_source_f.output_base_dir)
        data_source_l = common_pb.DataSource()
        data_source_l.data_source_meta.name = "milestone"
        data_source_l.data_source_meta.partition_num = 1
        data_source_l.output_base_dir = "./output-l"
        self.raw_data_dir_l = "./raw_data-l"
        self.data_source_l = data_source_l
        if gfile.Exists(self.data_source_l.output_base_dir):
            gfile.DeleteRecursively(self.data_source_l.output_base_dir)
        if gfile.Exists(self.raw_data_dir_l):
            gfile.DeleteRecursively(self.raw_data_dir_l)
        self.kvstore = db_client.DBClient('etcd', True)
        self.kvstore.delete_prefix(
            common.data_source_kvstore_base_dir(
                self.data_source_l.data_source_meta.name))
        self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.kvstore, self.data_source_l)
Example #11
    def _prepare(self):
        self._output_dir = os.path.join(
            self._options.output_dir,
            common.partition_repr(self._partition_id))
        if gfile.Exists(self._output_dir):
            gfile.DeleteRecursively(self._output_dir)
        gfile.MkDir(self._options.output_dir)
        gfile.MkDir(self._output_dir)
        for fpath_id, fpath in enumerate(self._fpaths):
            fpath = "{}/{}".format(self._options.input_dir, fpath)
            reader = Merge.InputFileReader(fpath_id, fpath, self._options)
            self._readers.append(reader)
            self._active_fpath.add(fpath_id)
            logging.info("Merge partition_id:%d, path:%s", self._partition_id,
                         fpath)
        self._preload_queue()
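
_prepare wires up one InputFileReader per input file and then preloads a queue, which is the setup for a k-way merge over already-sorted readers. The generic shape of that merge, sketched with heapq (the reader internals are assumptions; the real Merge class is not shown here, and items must be mutually comparable):

import heapq

def k_way_merge(readers):
    # Seed the heap with the first item of each reader (the "preload"),
    # then repeatedly emit the smallest item and refill from its reader.
    iters = [iter(r) for r in readers]
    heap = []
    for idx, it in enumerate(iters):
        first = next(it, None)
        if first is not None:
            heap.append((first, idx))
    heapq.heapify(heap)
    while heap:
        item, idx = heapq.heappop(heap)
        yield item
        nxt = next(iters[idx], None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, idx))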
Example #12
    def setUp(self):
        self.sche = _TaskScheduler(30)
        self.kv_store = [None, None]
        self.app_id = "test_trainer_v1"
        db_database, db_addr, db_username, db_password, db_base_dir = \
                get_kvstore_config("etcd")
        data_source = [
            self._gen_ds_meta(common_pb.FLRole.Leader),
            self._gen_ds_meta(common_pb.FLRole.Follower)
        ]
        for role in range(2):
            self.kv_store[role] = mysql_client.DBClient(
                data_source[role].data_source_meta.name, db_addr, db_username,
                db_password, db_base_dir, True)
        self.data_source = data_source
        (x, y) = (None, None)
        if debug_mode:
            (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
        else:
            (x, y), _ = tf.keras.datasets.mnist.load_data()
        x = x[:200, ]

        x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
        y = y.astype(np.int64)

        xl = x[:, :x.shape[1] // 2]
        xf = x[:, x.shape[1] // 2:]

        self._create_local_data(xl, xf, y)

        x = [xl, xf]
        for role in range(2):
            common.commit_data_source(self.kv_store[role], data_source[role])
            if gfile.Exists(data_source[role].output_base_dir):
                gfile.DeleteRecursively(data_source[role].output_base_dir)
            manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
                self.kv_store[role], data_source[role])
            partition_num = data_source[role].data_source_meta.partition_num
            for i in range(partition_num):
                self._create_data_block(data_source[role], i, x[role], y)
                # alternatively pass labels only to the leader:
                # x[role], y if role == 0 else None
                manifest_manager._finish_partition(
                    'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                    dj_pb.JoinExampleState.Joined, -1, i)