def tearDown(self):
    if gfile.Exists(self.data_source_f.output_base_dir):
        gfile.DeleteRecursively(self.data_source_f.output_base_dir)
    if gfile.Exists(self.data_source_l.output_base_dir):
        gfile.DeleteRecursively(self.data_source_l.output_base_dir)
    if gfile.Exists(self.raw_data_dir_l):
        gfile.DeleteRecursively(self.raw_data_dir_l)
    # Clean up the kvstore prefix created in setUp.
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(
            self.data_source_l.data_source_meta.name))
def setUp(self) -> None:
    logging.getLogger().setLevel(logging.DEBUG)
    self._data_portal_name = 'test_data_portal_job_manager'
    self._kvstore = DBClient('etcd', True)
    self._portal_input_base_dir = './portal_input_dir'
    self._portal_output_base_dir = './portal_output_dir'
    self._raw_data_publish_dir = 'raw_data_publish_dir'
    if gfile.Exists(self._portal_input_base_dir):
        gfile.DeleteRecursively(self._portal_input_base_dir)
    gfile.MakeDirs(self._portal_input_base_dir)
    self._data_fnames = ['1001/{}.data'.format(i) for i in range(100)]
    self._data_fnames_without_success = \
        ['1002/{}.data'.format(i) for i in range(100)]
    self._csv_fnames = ['1003/{}.csv'.format(i) for i in range(100)]
    self._unused_fnames = ['{}.xx'.format(100)]
    self._all_fnames = self._data_fnames + \
        self._data_fnames_without_success + \
        self._csv_fnames + self._unused_fnames
    all_fnames_with_success = ['1001/_SUCCESS'] + ['1003/_SUCCESS'] + \
        self._all_fnames
    for fname in all_fnames_with_success:
        fpath = os.path.join(self._portal_input_base_dir, fname)
        gfile.MakeDirs(os.path.dirname(fpath))
        with gfile.Open(fpath, "w") as f:
            f.write('xxx')
def test_raw_data_manifest_manager_with_nfs(self):
    root_dir = "test_fedlearner"
    os.environ["STORAGE_ROOT_PATH"] = root_dir
    cli = db_client.DBClient('nfs', True)
    self._raw_data_manifest_manager(cli)
    if gfile.Exists(root_dir):
        gfile.DeleteRecursively(root_dir)
def _create_local_data(self, xl, xf, y):
    N = 10
    chunk_size = xl.shape[0] // N
    leader_worker_path = os.path.join(output_path, "data/leader")
    follower_worker_path = os.path.join(output_path, "data/follower")
    data_path = os.path.join(output_path, "data")
    if gfile.Exists(data_path):
        gfile.DeleteRecursively(data_path)
    os.makedirs(leader_worker_path)
    os.makedirs(follower_worker_path)
    for i in range(N):
        filename_l = os.path.join(leader_worker_path, '%02d.tfrecord' % i)
        filename_f = os.path.join(follower_worker_path, '%02d.tfrecord' % i)
        fl = tf.io.TFRecordWriter(filename_l)
        ff = tf.io.TFRecordWriter(filename_f)
        for j in range(chunk_size):
            idx = i * chunk_size + j
            # The leader record carries the label 'y' plus its feature slice;
            # the follower record carries only its feature slice.
            features_l = {}
            features_l['example_id'] = Feature(
                bytes_list=BytesList(value=[str(idx).encode('utf-8')]))
            features_l['y'] = Feature(int64_list=Int64List(value=[y[idx]]))
            features_l['x'] = Feature(float_list=FloatList(value=list(xl[idx])))
            fl.write(
                Example(features=Features(feature=features_l)).SerializeToString())
            features_f = {}
            features_f['example_id'] = Feature(
                bytes_list=BytesList(value=[str(idx).encode('utf-8')]))
            features_f['x'] = Feature(float_list=FloatList(value=list(xf[idx])))
            ff.write(
                Example(features=Features(feature=features_f)).SerializeToString())
        fl.close()
        ff.close()
def test_potral_hourly_input_reducer_mapper(self):
    self._prepare_test()
    reducer = PotralHourlyInputReducer(self._portal_manifest,
                                       self._portal_options,
                                       self._date_time)
    mapper = PotralHourlyOutputMapper(self._portal_manifest,
                                      self._portal_options,
                                      self._date_time)
    expected_example_idx = 0
    for tf_item in reducer.make_reducer():
        example_id = '{}'.format(expected_example_idx).encode()
        mapper.map_data(tf_item)
        self.assertEqual(tf_item.example_id, example_id)
        expected_example_idx += 1
        if expected_example_idx % 7 == 0:
            expected_example_idx += 1
    mapper.finish_map()
    for partition_id in range(self._portal_manifest.output_partition_num):
        fpath = common.encode_portal_hourly_fpath(
            self._portal_manifest.output_data_base_dir,
            self._date_time, partition_id)
        freader = PotralHourlyInputReducer.InputFileReader(
            partition_id, fpath, self._portal_options)
        for example_idx in range(self._total_item_num):
            example_id = '{}'.format(example_idx).encode()
            if example_idx != 0 and (example_idx % 7) == 0:
                continue
            if partition_id != CityHash32(example_id) % \
                    self._portal_manifest.output_partition_num:
                continue
            for item in freader:
                self.assertEqual(example_id, item.tf_example_item.example_id)
                break
            self.assertFalse(freader.finished)
        try:
            next(freader)
        except StopIteration:
            self.assertTrue(True)
        else:
            self.assertTrue(False)
        self.assertTrue(freader.finished)
    if gfile.Exists(self._portal_manifest.input_data_base_dir):
        gfile.DeleteRecursively(self._portal_manifest.input_data_base_dir)
    if gfile.Exists(self._portal_manifest.output_data_base_dir):
        gfile.DeleteRecursively(self._portal_manifest.output_data_base_dir)
def make_ckpt_dir(role, remote="local", rank=None): if rank is None: rank = "N" ckpt_path = "{}/{}_ckpt_{}_{}".format(output_path, remote, role, rank) exp_path = "{}/saved_model".format(ckpt_path) if gfile.Exists(ckpt_path): gfile.DeleteRecursively(ckpt_path) return ckpt_path, exp_path
def _generate_input_data(self):
    self._total_item_num = 1 << 16
    self.assertEqual(
        self._total_item_num % self._portal_manifest.input_partition_num, 0)
    if gfile.Exists(self._portal_manifest.input_data_base_dir):
        gfile.DeleteRecursively(self._portal_manifest.input_data_base_dir)
    if gfile.Exists(self._portal_manifest.output_data_base_dir):
        gfile.DeleteRecursively(self._portal_manifest.output_data_base_dir)
    hourly_dir = common.encode_portal_hourly_dir(
        self._portal_manifest.input_data_base_dir, self._date_time)
    gfile.MakeDirs(hourly_dir)
    for partition_id in range(self._portal_manifest.input_partition_num):
        self._generate_one_part(partition_id)
    succ_tag_fpath = common.encode_portal_hourly_finish_tag(
        self._portal_manifest.input_data_base_dir, self._date_time)
    with gfile.GFile(succ_tag_fpath, 'w') as fh:
        fh.write('')
def setUp(self):
    self.data_source = common_pb.DataSource()
    self.data_source.data_source_meta.name = 'fclh_test'
    self.data_source.data_source_meta.partition_num = 1
    self.raw_data_dir = "./raw_data"
    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(
            self.data_source.data_source_meta.name))
    self.assertEqual(self.data_source.data_source_meta.partition_num, 1)
    partition_dir = os.path.join(self.raw_data_dir, common.partition_repr(0))
    if gfile.Exists(partition_dir):
        gfile.DeleteRecursively(partition_dir)
    gfile.MakeDirs(partition_dir)
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source)
def _start_new(self):
    if self._task is not None and self._task.is_alive():
        logging.info("%s is alive, no need to start new", self.name)
        return
    self._task = Process(target=self._target, name=self.name,
                         args=self._args, kwargs=self._kwargs,
                         daemon=self._daemon)
    if isinstance(self._args[0], Args):
        logging.info("delete %s", self._args[0].export_path)
        if gfile.Exists(self._args[0].export_path):
            logging.info("deleting")
            gfile.DeleteRecursively(self._args[0].export_path)
    self._task.start()
    logging.info("Task starts %s", self.name)
    time.sleep(10)
def setUp(self):
    data_source_f = common_pb.DataSource()
    data_source_f.data_source_meta.name = "milestone"
    data_source_f.data_source_meta.partition_num = 1
    data_source_f.output_base_dir = "./output-f"
    self.data_source_f = data_source_f
    if gfile.Exists(self.data_source_f.output_base_dir):
        gfile.DeleteRecursively(self.data_source_f.output_base_dir)
    data_source_l = common_pb.DataSource()
    data_source_l.data_source_meta.name = "milestone"
    data_source_l.data_source_meta.partition_num = 1
    data_source_l.output_base_dir = "./output-l"
    self.raw_data_dir_l = "./raw_data-l"
    self.data_source_l = data_source_l
    if gfile.Exists(self.data_source_l.output_base_dir):
        gfile.DeleteRecursively(self.data_source_l.output_base_dir)
    if gfile.Exists(self.raw_data_dir_l):
        gfile.DeleteRecursively(self.raw_data_dir_l)
    self.kvstore = db_client.DBClient('etcd', True)
    self.kvstore.delete_prefix(
        common.data_source_kvstore_base_dir(
            self.data_source_l.data_source_meta.name))
    self.manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
        self.kvstore, self.data_source_l)
def _prepare(self):
    self._output_dir = os.path.join(
        self._options.output_dir, common.partition_repr(self._partition_id))
    if gfile.Exists(self._output_dir):
        gfile.DeleteRecursively(self._output_dir)
    gfile.MkDir(self._options.output_dir)
    gfile.MkDir(self._output_dir)
    for fpath_id, fpath in enumerate(self._fpaths):
        fpath = "{}/{}".format(self._options.input_dir, fpath)
        reader = Merge.InputFileReader(fpath_id, fpath, self._options)
        self._readers.append(reader)
        self._active_fpath.add(fpath_id)
        logging.info("Merge partition_id:%d, path:%s",
                     self._partition_id, fpath)
    self._preload_queue()
def setUp(self):
    self.sche = _TaskScheduler(30)
    self.kv_store = [None, None]
    self.app_id = "test_trainer_v1"
    db_database, db_addr, db_username, db_password, db_base_dir = \
        get_kvstore_config("etcd")
    data_source = [
        self._gen_ds_meta(common_pb.FLRole.Leader),
        self._gen_ds_meta(common_pb.FLRole.Follower)
    ]
    for role in range(2):
        self.kv_store[role] = mysql_client.DBClient(
            data_source[role].data_source_meta.name, db_addr, db_username,
            db_password, db_base_dir, True)
    self.data_source = data_source
    (x, y) = (None, None)
    if debug_mode:
        (x, y), _ = tf.keras.datasets.mnist.load_data(local_mnist_path)
    else:
        (x, y), _ = tf.keras.datasets.mnist.load_data()
    x = x[:200, ]
    x = x.reshape(x.shape[0], -1).astype(np.float32) / 255.0
    y = y.astype(np.int64)
    # Split each flattened MNIST image in half: the leader keeps the left
    # half of the features, the follower keeps the right half.
    xl = x[:, :x.shape[1] // 2]
    xf = x[:, x.shape[1] // 2:]
    self._create_local_data(xl, xf, y)
    x = [xl, xf]
    for role in range(2):
        common.commit_data_source(self.kv_store[role], data_source[role])
        if gfile.Exists(data_source[role].output_base_dir):
            gfile.DeleteRecursively(data_source[role].output_base_dir)
        manifest_manager = raw_data_manifest_manager.RawDataManifestManager(
            self.kv_store[role], data_source[role])
        partition_num = data_source[role].data_source_meta.partition_num
        for i in range(partition_num):
            self._create_data_block(data_source[role], i, x[role], y)
            # x[role], y if role == 0 else None)
            manifest_manager._finish_partition(
                'join_example_rep', dj_pb.JoinExampleState.UnJoined,
                dj_pb.JoinExampleState.Joined, -1, i)