def __init__(
    self,
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = SINGLE_TRAINER,
    subprocess_init: Optional[Callable[[], None]] = None,
    stats_handler: StatsHandler = NOOP_STATS_HANDLER,
):
    super().__init__(
        config, model, trainer, evaluator, rank, subprocess_init, stats_handler
    )

    assert config.num_gpus > 0
    if not CPP_INSTALLED:
        raise RuntimeError(
            "GPU support requires C++ installation: "
            "install with C++ support by running "
            "`PBG_INSTALL_CPP=1 pip install .`"
        )

    if config.half_precision:
        for entity in config.entities:
            # need this for tensor cores to work
            assert config.entity_dimension(entity) % 8 == 0
        assert config.batch_size % 8 == 0
        assert config.num_batch_negs % 8 == 0
        assert config.num_uniform_negs % 8 == 0

    assert len(self.holder.lhs_unpartitioned_types) == 0
    assert len(self.holder.rhs_unpartitioned_types) == 0

    num_edge_chunks = self.iteration_manager.num_edge_chunks
    max_edges = 0
    for edge_path in config.edge_paths:
        edge_storage = EDGE_STORAGES.make_instance(edge_path)
        for lhs_part in range(self.holder.nparts_lhs):
            for rhs_part in range(self.holder.nparts_rhs):
                num_edges = edge_storage.get_number_of_edges(lhs_part, rhs_part)
                num_edges_per_chunk = div_roundup(num_edges, num_edge_chunks)
                max_edges = max(max_edges, num_edges_per_chunk)
    self.shared_lhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
    self.shared_rhs = allocate_shared_tensor((max_edges,), dtype=torch.long)
    self.shared_rel = allocate_shared_tensor((max_edges,), dtype=torch.long)

    # fork early for HOGWILD threads
    logger.info("Creating GPU workers...")
    torch.set_num_threads(1)
    self.gpu_pool = GPUProcessPool(
        config.num_gpus,
        subprocess_init,
        {s for ss in self.embedding_storage_freelist.values() for s in ss}
        | {
            self.shared_lhs.storage(),
            self.shared_rhs.storage(),
            self.shared_rel.storage(),
        },
    )
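# Why the staging buffers above are allocated *before* the worker pool is
# created: tensors whose storage lives in shared memory prior to the fork are
# visible to every child process, so GPU workers can read and fill the
# lhs/rhs/rel buffers without copying or pickling the data itself. A minimal,
# PBG-independent sketch of that property, assuming `allocate_shared_tensor`
# behaves like `torch.empty(...).share_memory_()`:
import torch
import torch.multiprocessing as mp


def _fill(t: torch.Tensor) -> None:
    # Runs in the child process; the write lands in the shared storage and is
    # therefore observed by the parent.
    t.fill_(42)


if __name__ == "__main__":
    staging = torch.empty((8,), dtype=torch.long).share_memory_()
    p = mp.Process(target=_fill, args=(staging,))
    p.start()
    p.join()
    assert bool(staging.eq(42).all())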
def load_embeddings(hf: h5py.File) -> FloatTensorType:
    dataset: h5py.Dataset = hf[EMBEDDING_DATASET]
    embeddings = allocate_shared_tensor(dataset.shape, dtype=torch.float)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if dataset.size > 0:
        dataset.read_direct(embeddings.numpy())
    return embeddings
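# The `dataset.size > 0` guard above works around
# https://github.com/h5py/h5py/issues/870, where `Dataset.read_direct` raises
# on an empty selection. A standalone demonstration of the safe pattern (the
# file and dataset names here are scratch values for illustration):
import h5py
import numpy as np

with h5py.File("scratch.h5", "w") as hf:
    ds = hf.create_dataset("embeddings", shape=(0, 4), dtype="f4")
    out = np.empty(ds.shape, dtype=np.float32)
    if ds.size > 0:  # without this guard, read_direct would raise
        ds.read_direct(out)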
def get(
    self, key: str, dst: Optional[torch.Tensor] = None, shared: bool = False
) -> Optional[torch.Tensor]:
    """Get a tensor from the server."""
    self._validate_get(key, dst=dst, shared=shared)
    cmd_rpc = torch.tensor(
        [GET_CMD, len(key), dst is None, 0, 0, 0], dtype=torch.long
    )
    metadata_pg = self._metadata_pg()
    td.send(cmd_rpc, self.server_rank, group=metadata_pg)
    td.send(_fromstring(key), self.server_rank, group=metadata_pg)
    if dst is None:
        meta = torch.full((2,), -1, dtype=torch.long)
        td.recv(meta, src=self.server_rank, group=metadata_pg)
        ndim, ttype = meta
        if ndim.item() == -1:
            return None
        size = torch.full((ndim.item(),), -1, dtype=torch.long)
        td.recv(size, src=self.server_rank, group=metadata_pg)
        dtype = _dtypes[ttype.item()]
        if shared:
            dst = allocate_shared_tensor(size.tolist(), dtype=dtype)
        else:
            dst = torch.empty(size.tolist(), dtype=dtype)
    start_t = time.monotonic()
    data_pgs = self._data_pgs()
    if data_pgs is None:
        td.recv(dst, src=self.server_rank)
    else:
        outstanding_work = []
        flattened_dst = dst.flatten()
        flattened_size = flattened_dst.shape[0]
        for idx, (pg, slice_) in enumerate(
            zip(
                data_pgs,
                split_almost_equally(flattened_size, num_parts=len(data_pgs)),
            )
        ):
            outstanding_work.append(
                td.irecv(
                    tensor=flattened_dst[slice_],
                    src=self.server_rank,
                    group=pg,
                    tag=idx,
                )
            )
        for w in outstanding_work:
            w.wait()
    end_t = time.monotonic()
    if self.log_stats:
        stats_size = dst.numel() * dst.element_size()
        stats_time = end_t - start_t
        logger.debug(
            f"Received tensor {key} from server {self.server_rank}: "
            f"{stats_size:,} bytes "
            f"in {stats_time:,g} seconds "
            f"=> {stats_size / stats_time:,.0f} B/s"
        )
    return dst
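# `split_almost_equally` is not shown in this section; a plausible sketch,
# consistent with its use above, is a generator of `num_parts` contiguous
# slices that exactly cover [0, size) with lengths differing by at most one,
# so each process group receives a near-equal share of the flattened tensor:
from typing import Iterator


def split_almost_equally(size: int, *, num_parts: int) -> Iterator[slice]:
    # Spread the remainder one element at a time over the first parts.
    base, extra = divmod(size, num_parts)
    start = 0
    for i in range(num_parts):
        length = base + (1 if i < extra else 0)
        yield slice(start, start + length)
        start += length


assert list(split_almost_equally(10, num_parts=4)) == [
    slice(0, 3),
    slice(3, 6),
    slice(6, 8),
    slice(8, 10),
]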
def load_embeddings(
    hf: h5py.File, out: Optional[FloatTensorType] = None
) -> FloatTensorType:
    dataset: h5py.Dataset = hf[EMBEDDING_DATASET]
    if out is None:
        out = allocate_shared_tensor(dataset.shape, dtype=torch.float)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if dataset.size > 0:
        dataset.read_direct(out.numpy())
    return out
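# A hypothetical usage sketch for the `out`-accepting variant above: a caller
# holding a preallocated buffer (for instance one carved out of an embedding
# storage freelist) can reuse it across reads instead of allocating fresh
# shared memory each time. The file name here is made up for illustration.
import h5py
import torch

with h5py.File("checkpoint.h5", "r") as hf:
    buf = torch.empty(hf[EMBEDDING_DATASET].shape, dtype=torch.float)
    embs = load_embeddings(hf, out=buf)
    assert embs.data_ptr() == buf.data_ptr()  # no new allocation happened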
def load_chunk_of_edges(
    self,
    lhs_p: int,
    rhs_p: int,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    file_path = self.get_edges_file(lhs_p, rhs_p)
    try:
        with h5py.File(file_path, "r") as hf:
            if hf.attrs.get(FORMAT_VERSION_ATTR, None) != FORMAT_VERSION:
                raise RuntimeError(f"Version mismatch in edge file {file_path}")
            lhs_ds = hf["lhs"]
            rhs_ds = hf["rhs"]
            rel_ds = hf["rel"]

            num_edges = rel_ds.len()
            begin = int(chunk_idx * num_edges / num_chunks)
            end = int((chunk_idx + 1) * num_edges / num_chunks)
            chunk_size = end - begin

            lhs = allocate_shared_tensor((chunk_size,), dtype=torch.long)
            rhs = allocate_shared_tensor((chunk_size,), dtype=torch.long)
            rel = allocate_shared_tensor((chunk_size,), dtype=torch.long)

            # Needed because https://github.com/h5py/h5py/issues/870.
            if chunk_size > 0:
                lhs_ds.read_direct(lhs.numpy(), source_sel=np.s_[begin:end])
                rhs_ds.read_direct(rhs.numpy(), source_sel=np.s_[begin:end])
                rel_ds.read_direct(rel.numpy(), source_sel=np.s_[begin:end])

            lhsd = self.read_dynamic(hf, "lhsd", begin, end)
            rhsd = self.read_dynamic(hf, "rhsd", begin, end)

            return EdgeList(EntityList(lhs, lhsd), EntityList(rhs, rhsd), rel)
    except OSError as err:
        # h5py refuses to make it easy to figure out what went wrong. The errno
        # attribute is set to None. See https://github.com/h5py/h5py/issues/493.
        if f"errno = {errno.ENOENT}" in str(err):
            raise CouldNotLoadData() from err
        raise err
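# The begin/end arithmetic above partitions the [0, num_edges) range into
# num_chunks contiguous chunks with no edge dropped or duplicated: each
# chunk's `end` is, by construction, the next chunk's `begin`. A quick
# self-check of that invariant:
def chunk_bounds(num_edges: int, num_chunks: int):
    for chunk_idx in range(num_chunks):
        begin = int(chunk_idx * num_edges / num_chunks)
        end = int((chunk_idx + 1) * num_edges / num_chunks)
        yield begin, end


for n, k in [(10, 3), (7, 7), (0, 4), (1_000_003, 16)]:
    bounds = list(chunk_bounds(n, k))
    assert bounds[0][0] == 0 and bounds[-1][1] == n
    assert all(e1 == b2 for (_, e1), (b2, _) in zip(bounds, bounds[1:]))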
def read_dynamic(
    hf: h5py.File,
    key: str,
    begin: int,
    end: int,
) -> TensorList:
    try:
        offsets_ds = hf[f"{key}_offsets"]
        data_ds = hf[f"{key}_data"]
    except LookupError:
        return TensorList.empty(num_tensors=end - begin)

    offsets = allocate_shared_tensor((end - begin + 1,), dtype=torch.long)
    offsets_ds.read_direct(offsets.numpy(), source_sel=np.s_[begin : end + 1])
    data_begin = offsets[0].item()
    data_end = offsets[-1].item()
    data = allocate_shared_tensor((data_end - data_begin,), dtype=torch.long)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if data_end - data_begin > 0:
        data_ds.read_direct(data.numpy(), source_sel=np.s_[data_begin:data_end])
    offsets -= int(offsets[0])

    return TensorList(offsets, data)
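# The "<key>_offsets"/"<key>_data" pair read above is a CSR-style encoding of
# a ragged list of tensors: row i is data[offsets[i]:offsets[i+1]]. A small
# sketch (with made-up values) of slicing rows [begin, end) out of such an
# encoding, mirroring the offset rebasing done above:
import torch

offsets = torch.tensor([0, 2, 2, 5, 9])  # four ragged rows
data = torch.arange(9)

begin, end = 1, 3  # keep rows 1 and 2
sub_offsets = offsets[begin : end + 1].clone()
sub_data = data[int(sub_offsets[0]) : int(sub_offsets[-1])].clone()
sub_offsets -= int(sub_offsets[0])  # rebase into the loaded slice
# Row 0 of the slice is empty, row 1 is [2, 3, 4].
assert sub_data[int(sub_offsets[0]) : int(sub_offsets[1])].numel() == 0
assert sub_data[int(sub_offsets[1]) : int(sub_offsets[2])].tolist() == [2, 3, 4]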
def get(
    self,
    key: str,
    dst: Optional[torch.Tensor] = None,
    shared: bool = False,
) -> Optional[torch.Tensor]:
    """Get a tensor from the server."""
    cmd_rpc = torch.tensor(
        [GET_CMD, len(key), dst is None, 0, 0, 0], dtype=torch.long
    )
    td.send(cmd_rpc, self.server_rank)
    td.send(_fromstring(key), self.server_rank)
    if dst is None:
        meta = torch.full((2,), -1, dtype=torch.long)
        td.recv(meta, src=self.server_rank)
        ndim, ttype = meta
        if ndim.item() == -1:
            return None
        size = torch.full((ndim.item(),), -1, dtype=torch.long)
        td.recv(size, src=self.server_rank)
        dtype = _dtypes[ttype.item()]
        if shared:
            dst = allocate_shared_tensor(size.tolist(), dtype=dtype)
        else:
            dst = torch.empty(size.tolist(), dtype=dtype)
    start_t = time.monotonic()
    td.recv(dst, src=self.server_rank)
    end_t = time.monotonic()
    if self.log_stats:
        stats_size = dst.numel() * dst.element_size()
        stats_time = end_t - start_t
        logger.debug(
            f"Received tensor {key} from server {self.server_rank}: "
            f"{stats_size:,} bytes "
            f"in {stats_time:,g} seconds "
            f"=> {stats_size / stats_time:,.0f} B/s"
        )
    return dst
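# `_fromstring` is not shown in this section. Given that cmd_rpc carries the
# key's length and the key itself is then sent as a tensor, a plausible
# sketch is a helper pair that round-trips a str through a uint8 tensor (the
# real helper may encode differently):
import torch


def _fromstring(s: str) -> torch.Tensor:
    return torch.frombuffer(bytearray(s.encode("utf-8")), dtype=torch.uint8)


def _tostring(t: torch.Tensor) -> str:
    return bytes(t.tolist()).decode("utf-8")


assert _tostring(_fromstring("entity.user")) == "entity.user"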
def __init__(  # noqa
    self,
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = SINGLE_TRAINER,
    subprocess_init: Optional[Callable[[], None]] = None,
    stats_handler: StatsHandler = NOOP_STATS_HANDLER,
):
    """Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")
    self.config = config
    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    holder = self.holder = EmbeddingHolder(config)

    logger.debug(
        f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
        f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
    )

    # We know ahead of time that we will need 1-2 storages for each embedding
    # type, as well as the max size of this storage (num_entities x D).
    # We allocate these storages in advance in `embedding_storage_freelist`.
    # When we need storage for an entity type, we pop it from this free list,
    # and then add it back when we 'delete' the embedding table.
    embedding_storage_freelist: Dict[
        EntityName, Set[torch.FloatStorage]
    ] = defaultdict(set)
    for entity_type, counts in entity_counts.items():
        max_count = max(counts)
        num_sides = (
            (1 if entity_type in holder.lhs_partitioned_types else 0)
            + (1 if entity_type in holder.rhs_partitioned_types else 0)
            + (
                1
                if entity_type
                in (holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types)
                else 0
            )
        )
        for _ in range(num_sides):
            embedding_storage_freelist[entity_type].add(
                allocate_shared_tensor(
                    (max_count, config.entity_dimension(entity_type)),
                    dtype=torch.float,
                ).storage()
            )

    # create the handlers, threads, etc. for distributed training
    if config.num_machines > 1 or config.num_partition_servers > 0:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError(
                "The installed PyTorch version doesn't provide "
                "distributed training capabilities."
            )
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers
        )

        num_ps_groups = config.num_groups_for_partition_server
        groups: List[List[int]] = [ranks.trainers]  # barrier group
        groups += [
            ranks.trainers + ranks.partition_servers
        ] * num_ps_groups  # ps groups
        group_idxs_for_partition_servers = range(1, len(groups))

        if rank == SINGLE_TRAINER:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=holder.nparts_lhs,
                    nparts_rhs=holder.nparts_rhs,
                    entities_lhs=holder.lhs_partitioned_types,
                    entities_rhs=holder.rhs_partitioned_types,
                    entity_counts=entity_counts,
                    init_tree=config.distributed_tree_init_order,
                    stats_handler=stats_handler,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=groups,
                subprocess_init=subprocess_init,
            )

        self.bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server, client_rank=ranks.trainers[rank]
        )

        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=groups,
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=groups,
            subprocess_init=subprocess_init,
        )

        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(
                    num_clients=len(ranks.trainers),
                    group_idxs=group_idxs_for_partition_servers,
                    log_stats=True,
                ),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=groups,
        )
        trainer_group, *groups_for_partition_servers = groups
        self.barrier_group = trainer_group

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(
                ranks.partition_servers,
                groups=groups_for_partition_servers,
                log_stats=True,
            )
        else:
            partition_client = None
    else:
        self.barrier_group = None
        self.bucket_scheduler = SingleMachineBucketScheduler(
            holder.nparts_lhs, holder.nparts_rhs, config.bucket_order, stats_handler
        )
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    self.num_workers = get_num_workers(config.workers)
    self.pool = create_pool(
        self.num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    self.checkpoint_manager = checkpoint_manager
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    if rank == 0:
        checkpoint_manager.write_config(config)

    num_edge_chunks = get_num_edge_chunks(config)

    self.iteration_manager = IterationManager(
        config.num_epochs,
        config.edge_paths,
        num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version,
    )
    checkpoint_manager.register_metadata_provider(self.iteration_manager)

    logger.info("Initializing global model...")
    if model is None:
        model = make_model(config)
    model.share_memory()
    loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
    relation_weights = [relation.weight for relation in config.relations]
    if trainer is None:
        trainer = Trainer(
            model_optimizer=make_optimizer(config, model.parameters(), False),
            loss_fn=loss_fn,
            relation_weights=relation_weights,
        )
    if evaluator is None:
        eval_overrides = {}
        if config.eval_num_batch_negs is not None:
            eval_overrides["num_batch_negs"] = config.eval_num_batch_negs
        if config.eval_num_uniform_negs is not None:
            eval_overrides["num_uniform_negs"] = config.eval_num_uniform_negs

        evaluator = RankingEvaluator(
            loss_fn=loss_fn,
            relation_weights=relation_weights,
            overrides=eval_overrides,
        )

    if config.init_path is not None:
        self.loadpath_manager = CheckpointManager(config.init_path)
    else:
        self.loadpath_manager = None

    # load model from checkpoint or loadpath, if available
    state_dict, optim_state = checkpoint_manager.maybe_read_model()
    if state_dict is None and self.loadpath_manager is not None:
        state_dict, optim_state = self.loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.model_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        count = entity_counts[entity][0]
        s = embedding_storage_freelist[entity].pop()
        dimension = config.entity_dimension(entity)
        embs = torch.FloatTensor(s).view(-1, dimension)[:count]
        embs, optimizer = self._load_embeddings(entity, UNPARTITIONED, out=embs)
        holder.unpartitioned_embeddings[entity] = embs
        trainer.unpartitioned_optimizers[entity] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        shared_parameters: Set[int] = set()
        for name, param in model.named_parameters():
            if id(param) in shared_parameters:
                continue
            shared_parameters.add(id(param))
            key = f"model.{name}"
            logger.info(
                f"Adding {key} ({param.numel()} params) to parameter server"
            )
            parameter_sharer.set_param(key, param.data)
        for entity, embs in holder.unpartitioned_embeddings.items():
            key = f"entity.{entity}"
            logger.info(f"Adding {key} ({embs.numel()} params) to parameter server")
            parameter_sharer.set_param(key, embs.data)

    # store everything in self
    self.model = model
    self.trainer = trainer
    self.evaluator = evaluator
    self.rank = rank
    self.entity_counts = entity_counts
    self.embedding_storage_freelist = embedding_storage_freelist
    self.stats_handler = stats_handler

    self.strict = False
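# The freelist trick used above: one max-sized shared storage per entity type
# and side is allocated up front, and each (possibly smaller) embedding table
# is a view over the first `count` rows of that storage. A minimal sketch
# with made-up sizes, assuming `allocate_shared_tensor` behaves like
# `torch.empty(...).share_memory_()`:
import torch

max_count, dimension, count = 100, 16, 42
backing = torch.empty((max_count, dimension)).share_memory_().storage()

# Re-wrap the pooled storage as a (count, dimension) table, as done above.
embs = torch.FloatTensor(backing).view(-1, dimension)[:count]
assert embs.shape == (count, dimension)
# Writes land directly in the shared backing storage; "freeing" the table is
# just dropping this view and returning `backing` to the freelist.
embs.fill_(0.0)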