def release_bucket(self, bucket: Bucket) -> None:
    """Release the lock held through `bucket`.

    The bucket is removed from the table of active buckets and the new
    table is logged. A bucket whose `lhs` is None is ignored.
    """
    if bucket.lhs is None:
        # Nothing was locked for such a bucket.
        return
    self.active.pop(bucket)
    log("lockserver release %s: active= %s" % (bucket, self.active))
def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
    """Build the MultiRelationEmbedder described by `config`.

    Raises:
        RuntimeError: if dynamic relations are enabled but `config.relations`
            does not hold exactly one entry, if the dynamic relation count
            file is missing from the entity path, or if the batch size is not
            a multiple of `num_batch_negs`.
        NotImplementedError: if `config.comparator` is not a known comparator.
    """
    num_dynamic_rels = 0
    if config.dynamic_relations:
        if len(config.relations) != 1:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should only be one "
                "entry in config.relations with config for all relations.")
        if maybe_old_entity_path(config.entity_path):
            log("WARNING: It may be that your entity path contains files using "
                "the old format. See D14241362 for how to update them.")
        count_file = os.path.join(config.entity_path, "dynamic_rel_count.txt")
        try:
            with open(count_file, "rt") as tf:
                num_dynamic_rels = int(tf.read().strip())
        except FileNotFoundError:
            raise RuntimeError(
                "Dynamic relations are enabled, so there should be a file called "
                "dynamic_rel_count.txt in the entity path with their count.")

    batch_negs = config.num_batch_negs
    if batch_negs > 0 and config.batch_size % batch_negs != 0:
        raise RuntimeError(
            "Batch size (%d) must be a multiple of num_batch_negs (%d)"
            % (config.batch_size, config.num_batch_negs))

    lhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
    rhs_operators: List[Optional[Union[AbstractOperator, AbstractDynamicOperator]]] = []
    # For every relation the LHS operator is created before the RHS one, and
    # relations are handled in their configured order.
    for relation in config.relations:
        for side, operators in ((Side.LHS, lhs_operators),
                                (Side.RHS, rhs_operators)):
            operators.append(instantiate_operator(
                relation.operator, side, num_dynamic_rels, config.dimension))

    try:
        comparator_class = COMPARATORS[config.comparator]
    except KeyError:
        raise NotImplementedError("Unknown comparator: %s" % config.comparator)
    comparator = comparator_class()
    if config.bias:
        comparator = BiasedComparator(comparator)

    embedder_kwargs = dict(
        num_uniform_negs=config.num_uniform_negs,
        num_batch_negs=config.num_batch_negs,
        lhs_operators=lhs_operators,
        rhs_operators=rhs_operators,
        comparator=comparator,
        global_emb=config.global_emb,
        max_norm=config.max_norm,
        num_dynamic_rels=num_dynamic_rels,
    )
    return MultiRelationEmbedder(
        config.dimension,
        config.relations,
        config.entities,
        **embedder_kwargs,
    )
def __init__(self, config: ConfigSchema, filter_paths: List[str]):
    """Validate `config` and index every edge found under `filter_paths`.

    Two lookup tables are built: `lhs_map` maps (lhs, rel) to the list of
    rhs entities seen with it, and `rhs_map` maps (rhs, rel) to the list of
    lhs entities seen with it.

    Raises:
        RuntimeError: if the config has more than one relation or entity
            type, does not use all negatives, or its entity type is
            featurized or partitioned.
    """
    if not (len(config.relations) == 1 and len(config.entities) == 1):
        raise RuntimeError("Filtered ranking evaluation should only be used "
                           "with dynamic relations and one entity type.")
    if not config.relations[0].all_negs:
        raise RuntimeError("Filtered Eval can only be done with all negatives.")

    [entity] = config.entities.values()
    if entity.featurized:
        raise RuntimeError("Entity cannot be featurized for filtered eval.")
    if entity.num_partitions > 1:
        raise RuntimeError("Entity cannot be partitioned for filtered eval.")

    self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)

    for path in filter_paths:
        log("Building links map from path %s" % path)
        reader = EdgeReader(path)
        # The checks above guarantee a single partition, so only bucket
        # (0, 0) is read.
        lhs, rhs, rel = reader.read(Partition(0), Partition(0))
        for idx in range(lhs.size(0)):
            # Entities are known to be non-featurized here.
            src = lhs[idx].collapse(is_featurized=False).item()
            relation = rel[idx].item()
            dst = rhs[idx].collapse(is_featurized=False).item()

            self.lhs_map[src, relation].append(dst)
            self.rhs_map[dst, relation].append(src)

        log("Done building links map from path %s" % path)
def __init__(self, path: str) -> None:
    """Remember the edge directory `path`.

    Raises:
        RuntimeError: if `path` is not an existing directory.
    """
    path_is_dir = os.path.isdir(path)
    if not path_is_dir:
        raise RuntimeError("Invalid edge dir: %s" % path)

    looks_outdated = maybe_old_edge_path(path)
    if looks_outdated:
        # Only a warning: the reader is still constructed.
        log("WARNING: It may be that one of your edge paths contains files "
            "using the old format. See D14241362 for how to update them.")

    self.path: str = path
def share_model_params(self, model: nn.Module) -> None:
    """Register each distinct tensor of `model`'s state dict via `set_param`.

    Tensors are de-duplicated on their `_cdata` value, so a tensor that
    appears under several state-dict keys is registered once, under the
    first key encountered.
    """
    # NOTE(review): `_cdata` presumably identifies the underlying tensor
    # object; confirm against the PyTorch version in use.
    already_shared: Set[int] = set()
    state = ModuleStateDict(model.state_dict())
    for name, tensor in state.items():
        handle = tensor._cdata
        if handle in already_shared:
            continue
        already_shared.add(handle)
        log("Adding %s (%d params) to parameter server" % (name, tensor.numel()))
        self.set_param(name, tensor.data)
def acquire_bucket(
    self,
    rank: Rank,
    maybe_old_bucket: Optional[Bucket] = None,
) -> Tuple[Optional[Bucket], int]:
    """Try to hand an available bucket to the trainer `rank`.

    A bucket is eligible when it is not yet in `self.done`, when
    `_can_acquire` accepts both its lhs and rhs partitions for `rank`, and,
    if `self.initialized_partitions` is tracked, when at least one of its
    two partitions is already in that set. The first eligible bucket is
    recorded in `self.active` and `self.done` and returned.

    When `maybe_old_bucket` is given, candidates are tried in order of
    affinity with it: same lhs first, then same rhs.

    Returns:
        A `(bucket, remaining)` tuple. `bucket` is None if nothing could be
        acquired. `remaining` is the number of buckets not yet done,
        counted before this call acquires anything.
    """
    remaining = len(self.buckets) - len(self.done)

    if maybe_old_bucket is None:
        candidates = self.buckets
    else:
        # Alias to a non-Optional name so that the closure below captures a
        # value type checkers know is not None.
        old_bucket = maybe_old_bucket

        def affinity(b):
            # Sharing the lhs weighs 2, sharing the rhs weighs 1; negated so
            # that the best matches sort first.
            return -(2 * (b.lhs == old_bucket.lhs) + (b.rhs == old_bucket.rhs))

        candidates = sorted(self.buckets, key=affinity)

    # Partition -> rank of the trainer that currently holds it.
    locked_partitions = {}
    for held_bucket, owner in self.active.items():
        for side in self.locked_sides:
            locked_partitions[held_bucket.get_partition(side)] = owner

    for pair in candidates:
        if pair in self.done:
            continue
        if not self._can_acquire(rank, pair.lhs, locked_partitions, Side.LHS):
            continue
        if not self._can_acquire(rank, pair.rhs, locked_partitions, Side.RHS):
            continue
        if (self.initialized_partitions is not None
                and pair.lhs not in self.initialized_partitions
                and pair.rhs not in self.initialized_partitions):
            continue

        self.active[pair] = rank
        self.done.add(pair)
        if self.initialized_partitions is not None:
            self.initialized_partitions.add(pair.lhs)
            self.initialized_partitions.add(pair.rhs)
        log("lockserver %d acquire %s: active= %s" %
            (rank, pair, self.active))
        return pair, remaining

    return None, remaining
def __init__(
    self,
    path: str,
    rank: Rank = -1,
    num_machines: int = 1,
    background: bool = False,
    partition_client: Optional[PartitionClient] = None,
) -> None:
    """Set up checkpoint bookkeeping for the directory `path`.

    Args:
        path: folder containing the checkpoints.
        rank: rank of this process; rank 0 creates `path` if needed.
        num_machines: stored as-is on the instance.
        background: if True, a one-process pool is created and prefetch
            bookkeeping is initialised.
        partition_client: stored as-is on the instance.
    """
    if maybe_old_checkpoint_path(path):
        log("WARNING: It may be that your checkpoint path (or your init "
            "path) contains files using the old format. See D14241362 for "
            "how to update them.")

    self.path: str = path
    self.dirty: Set[Tuple[EntityName, Partition]] = set()
    self.rank: Rank = rank
    self.num_machines: int = num_machines
    if self.rank == 0:
        os.makedirs(self.path, exist_ok=True)

    # FIXME: there's a slight danger here, say that a multi-machine job fails
    # after a few versions, and then it reruns but one of the write_version=False
    # machines has cached the metadata and thinks it doesn't exist, then it
    # will expect checkpoint_version=0 and fail.
    version_file = os.path.join(self.path, VERSION_FILE)
    try:
        with open(version_file, "rt") as tf:
            version_string = tf.read().strip()
    except FileNotFoundError:
        version_string = ""
    # An empty file counts as version 0 too: on some distributed filesystems
    # creating the file and writing "0" to it are separate actions, so a
    # trainer may observe the file while it is still empty.
    self.checkpoint_version = int(version_string) if version_string else 0

    self.background: bool = background
    if self.background:
        self.pool: mp.Pool = create_pool(1)
        # FIXME In py-3.7 switch to typing.OrderedDict[str, AsyncResult].
        self.outstanding: OrderedDict = OrderedDict()
        self.prefetched: Dict[str, Tuple[FloatTensorType, Optional[OptimizerStateDict]]] = {}

    self.partition_client = partition_client

    self.metadata_providers: List[MetadataProvider] = []
def _sync(self, sync_path: Optional[str] = None) -> None:
    """Collect outstanding background results, oldest first.

    Each non-None result is stored in `self.prefetched` under its path.
    If `sync_path` is given, collection stops right after that path has
    been handled; otherwise every outstanding entry is drained.
    """
    assert self.background
    vlog("CheckpointManager=>_sync( %s )" % sync_path)
    vlog("outstanding= %s" % set(self.outstanding))
    while self.outstanding:
        # last=False pops in insertion (FIFO) order.
        path, pending = self.outstanding.popitem(last=False)
        result = get_async_result(pending, self.pool)
        if result is not None:
            log("Setting prefetched %s; %d outstanding" %
                (path, len(self.outstanding)))
            self.prefetched[path] = result

        reached_target = sync_path is not None and path == sync_path
        if reached_target:
            break
def read(
    self,
    lhs_p: Partition,
    rhs_p: Partition,
    chunk_idx: int = 0,
    num_chunks: int = 1,
) -> EdgeList:
    """Load one chunk of the edges stored for bucket (`lhs_p`, `rhs_p`).

    The edges of the bucket file are split into `num_chunks` contiguous
    ranges and the range with index `chunk_idx` is returned.

    Raises:
        AssertionError: if the bucket file does not exist.
        RuntimeError: if the file carries a format version other than
            FORMAT_VERSION.
    """
    file_name = "edges_%d_%d.h5" % (lhs_p, rhs_p)
    file_path = os.path.join(self.path, file_name)
    assert os.path.exists(file_path), "%s does not exist" % file_path
    with h5py.File(file_path, 'r') as hf:
        if FORMAT_VERSION_ATTR not in hf.attrs:
            log("WARNING: It may be that one of your edge paths contains "
                "files using the old format. See D14241362 for how to "
                "update them.")
        elif hf.attrs[FORMAT_VERSION_ATTR] != FORMAT_VERSION:
            raise RuntimeError("Version mismatch in edge file %s" % file_path)

        lhs_ds = hf['lhs']
        rhs_ds = hf['rhs']
        rel_ds = hf['rel']

        total = rel_ds.len()
        begin = int(chunk_idx * total / num_chunks)
        end = int((chunk_idx + 1) * total / num_chunks)
        span = end - begin

        lhs = torch.empty((span,), dtype=torch.long)
        rhs = torch.empty((span,), dtype=torch.long)
        rel = torch.empty((span,), dtype=torch.long)

        # Empty reads are skipped because of
        # https://github.com/h5py/h5py/issues/870.
        if span > 0:
            window = np.s_[begin:end]
            for dataset, out in ((lhs_ds, lhs), (rhs_ds, rhs), (rel_ds, rel)):
                dataset.read_direct(out.numpy(), source_sel=window)

        lhsd = self.read_dynamic(hf, 'lhsd', begin, end)
        rhsd = self.read_dynamic(hf, 'rhsd', begin, end)

        lhs_entities = EntityList(lhs, lhsd)
        rhs_entities = EntityList(rhs, rhsd)
        return EdgeList(lhs_entities, rhs_entities, rel)
def init_process_group(
    init_method: Optional[str],
    world_size: int,
    rank: Rank,
    groups: List[List[Rank]],
    backend: str = "gloo",
) -> List['td.ProcessGroup']:
    """Initialise torch.distributed and create one process group per entry of `groups`.

    Returns:
        The created group objects, in the same order as `groups`.

    Raises:
        RuntimeError: if `init_method` is None.
    """
    # With the THD backend there were no timeouts so high variance in
    # execution time between trainers was not a problem. With the new c10d
    # implementation we do have to take timeouts into account. To simulate
    # the old behavior we use a ridiculously high default timeout.
    practically_never = timedelta(days=365)

    log("init_process_group start")
    if init_method is None:
        raise RuntimeError(
            "distributed_init_method must be set when num_machines > 1")

    td.init_process_group(
        backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
        timeout=practically_never,
    )

    log("init_process_group creating groups")
    group_objs = [
        td.new_group(members, timeout=practically_never)
        for members in groups
    ]
    log("init_process_group done")
    return group_objs
def prepare_negatives(
    self,
    pos_input: EntityList,
    pos_embs: FloatTensorType,
    module: AbstractEmbedding,
    type_: Negatives,
    num_uniform_neg: int,
    rel: Union[int, LongTensorType],
    entity_type: str,
    operator: Union[None, AbstractOperator, AbstractDynamicOperator],
) -> Tuple[FloatTensorType, Mask]:
    """Given some chunked positives, set up chunks of negatives.

    This function operates on one side (left-hand or right-hand) at a time.
    It takes all the information about the positives on that side (the
    original input value, the corresponding embeddings, and the module used
    to convert one to the other). It then produces negatives for that side
    according to the specified mode. The positive embeddings come in in
    chunked form and the negatives are produced within each of these chunks.

    The supported modes are:

    - Negatives.NONE: no negatives at all (an empty tensor per chunk).
    - Negatives.UNIFORM: only `num_uniform_neg` entities sampled through
      `module.sample_entities` for each chunk.
    - Negatives.BATCH_UNIFORM: the positives of the same chunk, optionally
      followed by uniformly-sampled entities (each chunk having a different
      sample). If the module raises NotImplementedError when sampling, only
      the positives are used.
    - Negatives.ALL: all the entities returned by
      `module.get_all_entities()`.

    Args:
        pos_input: the positive entities on this side, before embedding.
        pos_embs: the chunked embeddings of the positives; it is matched
            against a three-dimensional shape
            (num_chunks, chunk_size, dim).
        module: the embedding module used to sample or enumerate entities.
        type_: the mode for producing negatives (see above).
        num_uniform_neg: how many uniformly-sampled negatives to request
            per chunk.
        rel, entity_type, operator: forwarded to `self.adjust_embs` for
            every set of negative embeddings that is not simply `pos_embs`.

    Returns:
        Both the chunked embeddings of the negatives and a mask of the same
        size as the chunked positives-vs-negatives scores, whose non-zero
        elements correspond to the scores that must be ignored. The mask is
        expressed as a list of index tuples.

    Raises:
        NotImplementedError: if `type_` is not one of the modes above.
    """
    num_pos = len(pos_input)
    num_chunks, chunk_size, dim = match_shape(pos_embs, -1, -1, -1)
    # Number of actual positives in the last chunk; the remaining rows of
    # that chunk are treated as padding below.
    last_chunk_size = num_pos - (num_chunks - 1) * chunk_size

    ignore_mask: Mask = []
    if type_ is Negatives.NONE:
        neg_embs = torch.empty((num_chunks, 0, dim))
    elif type_ is Negatives.UNIFORM:
        uniform_neg_embs = module.sample_entities(num_chunks, num_uniform_neg)
        neg_embs = self.adjust_embs(
            uniform_neg_embs,
            rel,
            entity_type,
            operator,
        )
    elif type_ is Negatives.BATCH_UNIFORM:
        neg_embs = pos_embs
        if num_uniform_neg > 0:
            try:
                uniform_neg_embs = module.sample_entities(
                    num_chunks, num_uniform_neg)
            except NotImplementedError:
                pass  # only use pos_embs i.e. batch negatives
            else:
                # Uniform negatives go after the batch negatives, along
                # the per-chunk dimension.
                neg_embs = torch.cat([
                    pos_embs,
                    self.adjust_embs(
                        uniform_neg_embs,
                        rel,
                        entity_type,
                        operator,
                    )
                ], dim=1)

        chunk_indices = torch.arange(chunk_size, dtype=torch.long)
        last_chunk_indices = chunk_indices[:last_chunk_size]
        # Ignore scores between positive pairs.
        ignore_mask.append(
            (slice(num_chunks - 1), chunk_indices, chunk_indices))
        ignore_mask.append((-1, last_chunk_indices, last_chunk_indices))
        # In the last chunk, ignore the scores between the positives that
        # are not padding (i.e., the first last_chunk_size ones) and the
        # negatives that are padding (i.e., all of them except the first
        # last_chunk_size ones). Stop the last slice at chunk_size so that
        # it doesn't also affect the uniformly-sampled negatives.
        ignore_mask.append(
            (-1, slice(last_chunk_size), slice(last_chunk_size, chunk_size)))

    elif type_ is Negatives.ALL:
        # From here on pos_input is a tensor of entity indices rather than
        # an EntityList.
        pos_input = pos_input.to_tensor()
        neg_embs = self.adjust_embs(
            module.get_all_entities().expand(num_chunks, -1, dim),
            rel,
            entity_type,
            operator,
        )

        if num_uniform_neg > 0:
            log("WARNING: Adding uniform negatives makes no sense "
                "when already using all negatives")

        chunk_indices = torch.arange(chunk_size, dtype=torch.long)
        last_chunk_indices = chunk_indices[:last_chunk_size]
        # Ignore scores between positive pairs: since the i-th such pair has
        # the pos_input[i] entity on this side, ignore_mask[i, pos_input[i]]
        # must be set to 1 for every i. This becomes slightly more tricky as
        # the rows may be wrapped into multiple chunks (the last of which
        # may be smaller).
        ignore_mask.append((
            torch.arange(num_chunks - 1, dtype=torch.long).unsqueeze(1),
            chunk_indices.unsqueeze(0),
            pos_input[:-last_chunk_size].view(num_chunks - 1, chunk_size),
        ))
        ignore_mask.append(
            (-1, last_chunk_indices, pos_input[-last_chunk_size:]))

    else:
        raise NotImplementedError("Unknown negative type %s" % type_)

    return neg_embs, ignore_mask
def swap_partitioned_embeddings(
    old_b: Optional[Bucket],
    new_b: Optional[Bucket],
) -> int:
    """Replace the partitioned embeddings of `old_b` with those of `new_b`.

    Embeddings used by `old_b` but not by `new_b` are written to the
    checkpoint and their optimizers dropped; `old_b` is then released on the
    bucket scheduler. Embeddings used by both buckets are carried over in
    memory. Embeddings needed only by `new_b` are loaded.

    Either argument may be None: with `old_b` None nothing is checkpointed
    or released, with `new_b` None nothing is loaded.

    This function relies on names defined outside of it
    (`lhs_partitioned_types`, `rhs_partitioned_types`, `model`, `trainer`,
    `checkpoint_manager`, `bucket_scheduler`, `load_embeddings`,
    `make_optimizer` and `strict`).

    Returns:
        The number of bytes of embeddings written plus loaded (optimizer
        state is not counted).
    """
    # 0. given the old and new buckets, construct data structures to keep
    #    track of old and new embedding (entity, part) tuples

    io_bytes = 0
    log("Swapping partitioned embeddings %s %s" % (old_b, new_b))

    types = ([(e, Side.LHS) for e in lhs_partitioned_types]
             + [(e, Side.RHS) for e in rhs_partitioned_types])
    # Both map (entity, partition) -> side; they are empty when the
    # corresponding bucket is None.
    old_parts = {(e, old_b.get_partition(side)): side
                 for e, side in types if old_b is not None}
    new_parts = {(e, new_b.get_partition(side)): side
                 for e, side in types if new_b is not None}

    to_checkpoint = set(old_parts) - set(new_parts)
    preserved = set(old_parts) & set(new_parts)

    # 1. checkpoint embeddings that will not be used in the next pair
    #
    if old_b is not None:  # there are previous embeddings to checkpoint
        log("Writing partitioned embeddings")
        for entity, part in to_checkpoint:
            side = old_parts[(entity, part)]
            vlog("Checkpointing (%s %d %s)" %
                 (entity, part, side.pick("lhs", "rhs")))
            embs = model.get_embeddings(entity, side)
            optim_key = (entity, part)
            optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict())
            io_bytes += embs.numel() * embs.element_size()  # ignore optim state
            checkpoint_manager.write(entity, part, embs.detach(), optim_state)
            if optim_key in trainer.entity_optimizers:
                del trainer.entity_optimizers[optim_key]
            # these variables are holding large objects; let them be freed
            del embs
            del optim_state

        bucket_scheduler.release_bucket(old_b)

    # 2. copy old embeddings that will be used in the next pair
    #    into a temporary dictionary
    #
    tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

    # Every partitioned entity type is cleared on both sides, whichever
    # side it was listed under.
    for entity, _ in types:
        model.clear_embeddings(entity, Side.LHS)
        model.clear_embeddings(entity, Side.RHS)

    if new_b is None:  # there are no new embeddings to load
        return io_bytes

    # 3. load new embeddings into the model/optimizer, either from disk
    #    or the temporary dictionary
    #
    log("Loading entities")
    for entity, side in types:
        part = new_b.get_partition(side)
        part_key = (entity, part)
        if part_key in tmp_emb:
            vlog("Loading (%s, %d) from preserved" % (entity, part))
            embs, optim_state = tmp_emb[part_key], None
        else:
            vlog("Loading (%s, %d)" % (entity, part))

            force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
            embs, optim_state = load_embeddings(
                entity, part, strict=strict, force_dirty=force_dirty)
            io_bytes += embs.numel() * embs.element_size()  # ignore optim state

        model.set_embeddings(entity, embs, side)
        # Recorded so that a later (entity, side) entry that maps to the
        # same (entity, partition) reuses this tensor instead of loading a
        # second copy.
        tmp_emb[part_key] = embs

        optim_key = (entity, part)
        # Preserved partitions still have their optimizer; only create one
        # when it is missing.
        if optim_key not in trainer.entity_optimizers:
            vlog("Resetting optimizer %s" % (optim_key,))
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                vlog("Setting optim state")
                optimizer.load_state_dict(optim_state)

            trainer.entity_optimizers[optim_key] = optimizer

    return io_bytes
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
    """Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.

    Args:
        config: the training configuration.
        model: the model to train; built with `make_model(config)` if None.
        trainer: the batch processor used for training; a `Trainer` is
            created if None.
        evaluator: the batch processor used for the before/after
            evaluation; a `TrainingRankingEvaluator` is created if None.
        rank: the rank of this trainer; must satisfy
            0 <= rank < config.num_machines when more than one machine is
            configured.

    Yields:
        After every trained bucket, a tuple
        `(current_index, eval_stats_before, stats, eval_stats_after)`.
        The two evaluation entries are None when no edges were held out
        for evaluation.

    Raises:
        RuntimeError: in multi-machine mode, if `rank` is out of range or
            torch.distributed is not available.
    """

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    log("Loading entity counts...")

    if maybe_old_entity_path(config.entity_path):
        log("WARNING: It may be that your entity path contains files using the "
            "old format. See D14241362 for how to update them.")

    # entity name -> number of entities in each of its partitions.
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            with open(os.path.join(
                config.entity_path, "entity_count_%s_%d.txt" % (entity, part)
            ), "rt") as tf:
                entity_counts[entity].append(int(tf.read().strip()))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    vlog("nparts %d %d types %s %s" %
         (nparts_lhs, nparts_rhs, lhs_partitioned_types, rhs_partitioned_types))
    total_buckets = nparts_lhs * nparts_rhs

    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers)

        # Only the first trainer hosts the lock server.
        if rank == RANK_ZERO:
            log("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                server_rank=ranks.lock_server,
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        # Every trainer hosts a parameter server and runs a client for it.
        log("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            server_rank=ranks.parameter_servers[rank],
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        parameter_sharer = ParameterSharer(
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        # NOTE(review): -1 looks like a sentinel meaning "each trainer also
        # hosts a partition server" -- confirm against the config schema.
        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                server_rank=ranks.partition_servers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)
        # Messages about distributed coordination are only emitted in
        # multi-machine mode.
        dlog = log

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        dlog = lambda msg: None

    # fork early for HOGWILD threads
    log("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
        # Picks the optimizer by kind of parameters: a dummy one when there
        # is nothing to optimize, RowAdagrad for embeddings, and Adagrad
        # (with the relation learning rate, if set) for everything else.
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
    )
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    # Training resumes from the version found in the checkpoint directory.
    iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, config.num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        # Lookup order when not strict: the checkpoint, then the init path
        # (if any), then freshly initialized embeddings.
        if strict:
            embs, optim_state = checkpoint_manager.read(entity, part,
                                                        force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(entity, part,
                                                              force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity, entity_counts[entity][part],
                                              config.dimension, config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    log("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size, config.eval_num_batch_negs)

    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    # Fall back on the init path when the checkpoint holds no model yet.
    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    # Unpartitioned entities are loaded once, up front, and the same
    # parameter is installed on both sides.
    vlog("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    # Read by swap_partitioned_embeddings below; flipped to True once a
    # full iteration has been checkpointed.
    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        # Checkpoints the partitions only `old_b` uses, keeps in memory the
        # ones both buckets use, loads the ones only `new_b` uses, and
        # returns the number of embedding bytes written plus loaded.

        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        log("Swapping partitioned embeddings %s %s" % (old_b, new_b))

        types = ([(e, Side.LHS) for e in lhs_partitioned_types]
                 + [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            log("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                vlog("Checkpointing (%s %d %s)" %
                     (entity, part, side.pick("lhs", "rhs")))
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict())
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                checkpoint_manager.write(entity, part, embs.detach(), optim_state)
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        log("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                vlog("Loading (%s, %d) from preserved" % (entity, part))
                embs, optim_state = tmp_emb[part_key], None
            else:
                vlog("Loading (%s, %d)" % (entity, part))

                force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
                embs, optim_state = load_embeddings(
                    entity, part, strict=strict, force_dirty=force_dirty)
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state

            model.set_embeddings(entity, embs, side)
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                vlog("Resetting optimizer %s" % (optim_key,))
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    vlog("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx \
            in iteration_manager.remaining_iterations():
        log("Starting epoch %d / %d edge path %d / %d edge chunk %d / %d" %
            (epoch_idx + 1, iteration_manager.num_epochs,
             edge_path_idx + 1, iteration_manager.num_edge_paths,
             edge_chunk_idx + 1, iteration_manager.num_edge_chunks))
        edge_reader = EdgeReader(iteration_manager.edge_path)
        log("edge_path= %s" % iteration_manager.edge_path)

        sync.barrier()
        dlog("Lock client new epoch...")
        bucket_scheduler.new_pass(is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            print('still in queue: %d' % remaining, file=sys.stderr)
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            def log_status(msg, always=False):
                # Prefixes the message with the current bucket; uses the
                # verbose logger unless `always` is set.
                f = log if always else vlog
                f("%s: %s" % (cur_b, msg))

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            # Index of this bucket, counted across all iterations.
            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                log_status("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            log_status("Loading edges")
            edges = edge_reader.read(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx, config.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            log_status("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            g = torch.Generator()
            g.manual_seed(hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                # The seeded permutation decides which edges are held out
                # (its tail); the rest is then reshuffled without the fixed
                # seed to get the training order.
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_before = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                log("stats before %s: %s" % (cur_b, eval_stats_before))

            # Everything up to here (swap, edge loading, shuffling and the
            # pre-training evaluation) is accounted as I/O time.
            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            log_status("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            # NOTE: inside this comprehension `rank` is the index of the
            # worker within the pool, not the trainer rank passed to this
            # function; every worker but the first is delayed during the
            # first epoch.
            all_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay if epoch_idx == 0 and rank > 0 else 0,
                )
                for rank, s in enumerate(split_almost_equally(edge_perm.size(0),
                                                              num_parts=num_workers))
            ])
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            log_status(
                "bucket %d / %d : Processed %d edges in %.2f s "
                "( %.2g M/sec ); io: %.2f s ( %.2f MB/sec )" %
                (total_buckets - remaining, total_buckets,
                 num_edges, compute_time, num_edges / compute_time / 1e6,
                 io_time, io_bytes / io_time / 1e6),
                always=True)
            log_status("%s" % stats, always=True)

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_after = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                log("stats after %s: %s" % (cur_b, eval_stats_after))

            # Add train/eval metrics to queue
            yield current_index, eval_stats_before, stats, eval_stats_after

        # Checkpoint and release whatever bucket is still held (a no-op on
        # the checkpoint side when cur_b is None).
        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Write metadata: for multiple machines, write from rank-0
        log("Finished epoch %d path %d pass %d; checkpointing global state."
            % (epoch_idx + 1, edge_path_idx + 1, edge_chunk_idx + 1))
        log("My rank: %d" % rank)
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity, Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0),
                        embs.detach(), OptimizerStateDict(optimizer.state_dict()))

            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            log("Writing metadata...")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        log("Writing the checkpoint...")
        checkpoint_manager.write_new_version(config)

        dlog("Waiting for other workers to write their parts of the checkpoint: rank %d" % rank)
        sync.barrier()
        dlog("All parts of the checkpoint have been written")

        log("Switching to new checkpoint version...")
        checkpoint_manager.switch_to_new_version()

        dlog("Waiting for other workers to switch to the new checkpoint version: rank %d" % rank)
        sync.barrier()
        dlog("All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we remove the old checkpoints.
        checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    log("Exiting")
def _client_thread_loop(
    client_rank: Rank,
    all_server_ranks: List[Rank],
    q: mp.Queue,
    errq: mp.Queue,
    init_method: Optional[str],
    world_size: int,
    groups: List[List[Rank]],
    subprocess_init: Optional[Callable[[], None]] = None,
    max_bandwidth: float = 1e8,
    min_sleep_time: float = 0.01,
) -> None:
    """Main loop of a parameter client: keep registered parameters in sync.

    After the optional ``subprocess_init`` hook and ``init_process_group``,
    one ``GradientParameterClient`` is created per entry of
    ``all_server_ranks``. The loop then repeats until told to stop:

    1. Poll ``q`` (10 ms timeout) for a ``(cmd, args)`` message:
       ``"params"`` registers ``args[1]`` under the key ``args[0]``;
       ``"join"`` calls ``join()`` on every client and leaves the loop.
       Any other command is ignored.
    2. Call ``update`` (push and pull) for every registered parameter.
       A parameter larger than ``MIN_BYTES_TO_SHARD`` bytes is chunked along
       dim 0 and spread over all clients; a smaller one goes, whole, to the
       single client selected by ``hash(key)``.
    3. Sleep so that the round's traffic stays under ``max_bandwidth``
       (bytes per second), but never less than ``min_sleep_time`` seconds.

    While at least one parameter is registered, a throughput summary is
    logged once more than 60 seconds have passed since the previous summary
    (or since the last parameter registration).

    Any exception (including ``BaseException`` subclasses) has its traceback
    printed, is put on ``errq``, and is then re-raised.
    """
    try:
        if subprocess_init is not None:
            subprocess_init()
        init_process_group(
            rank=client_rank,
            init_method=init_method,
            world_size=world_size,
            groups=groups,
        )

        params = {}
        clients = [GradientParameterClient(rank) for rank in all_server_ranks]

        # Counters for the periodic throughput report.
        window_start = time.time()
        window_rounds = 0
        window_bytes = 0

        while True:
            round_start = time.time()
            round_bytes = 0

            # Step 1: check for a command from the main process.
            try:
                cmd, args = q.get(timeout=0.01)
                if cmd == "params":
                    params[args[0]] = args[1]
                    # A new parameter changes the workload, so restart the
                    # reporting window.
                    window_start = time.time()
                    window_rounds = 0
                    window_bytes = 0
                elif cmd == "join":
                    for client in clients:
                        client.join()
                    break
            except queue.Empty:
                pass

            # Step 2: push and pull every registered parameter.
            for key, tensor in params.items():
                nbytes = tensor.numel() * tensor.element_size()
                round_bytes += nbytes
                if nbytes <= MIN_BYTES_TO_SHARD:
                    # NOTE(review): hash() of str keys is salted per
                    # interpreter unless PYTHONHASHSEED is fixed. If keys are
                    # strings and clients run in independently started
                    # interpreters, they may choose different servers for the
                    # same key. Confirm against the callers.
                    clients[hash(key) % len(clients)].update(key, tensor)
                    continue
                pieces = tensor.chunk(len(clients), dim=0)
                for client, piece in zip(clients, pieces):
                    client.update(key, piece)

            window_bytes += round_bytes
            window_rounds += 1
            elapsed = time.time() - window_start
            if params and elapsed > 60:
                log("Parameter client synced %d rounds %g GB in %g s ( %g s/round , %g GB/s)"
                    % (window_rounds, window_bytes / 1e9, elapsed,
                       elapsed / window_rounds, window_bytes / elapsed / 1e9))
                window_start, window_rounds, window_bytes = time.time(), 0, 0

            # Step 3: if we're going too fast, sleep for a while.
            comm_time = time.time() - round_start
            time.sleep(
                max(round_bytes / max_bandwidth - comm_time, min_sleep_time))

    except BaseException as e:
        traceback.print_exc()
        errq.put(e)
        raise
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
    embeddings.

    Args:
        config: The configuration; its checkpoint path, edge paths, entities,
            batch size, workers and verbosity are read here.
        model: Model to evaluate. When None, one is built with
            ``make_model(config)``. In both cases the global state found in
            the checkpoint (if any) is loaded into it non-strictly.
        evaluator: Batch processor run on every batch of edges. Defaults to
            a ``RankingEvaluator``.
        subprocess_init: Optional hook forwarded to ``create_pool``.

    Yields:
        ``(edge_path_idx, bucket, stats)`` tuples, in this order:

        * for every bucket of every edge path, the mean stats of that bucket;
        * after the buckets of an edge path, ``(edge_path_idx, None, stats)``
          with the mean stats of the whole edge path;
        * finally ``(None, None, stats)`` with the mean over all edge paths.

    The worker pool is always released. It is closed and joined when the
    evaluation runs to completion, and terminated and joined when an
    exception is raised or the consumer stops iterating early.
    """
    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers, subprocess_init=subprocess_init)

    # Everything after the pool exists runs under try/finally so that worker
    # processes are not leaked on errors or when the generator is closed
    # before being exhausted.
    completed = False
    try:
        if model is None:
            model = make_model(config)
        model.share_memory()

        state_dict, _ = checkpoint_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        model.eval()

        # Unpartitioned entity types are loaded once, and the same parameter
        # is installed on both sides.
        for entity, econfig in config.entities.items():
            if econfig.num_partitions == 1:
                embs = load_embeddings(entity, Partition(0))
                model.set_embeddings(entity, embs, Side.LHS)
                model.set_embeddings(entity, embs, Side.RHS)

        num_edge_paths = len(config.edge_paths)
        all_stats: List[Stats] = []
        for edge_path_idx, edge_path in enumerate(config.edge_paths):
            log("Starting edge path %d / %d (%s)"
                % (edge_path_idx + 1, num_edge_paths, edge_path))
            edge_reader = EdgeReader(edge_path)

            all_edge_path_stats = []
            last_lhs, last_rhs = None, None
            for bucket in create_buckets_ordered_lexicographically(
                    nparts_lhs, nparts_rhs):
                tic = time.time()

                # Swap partitioned embeddings only on the side whose
                # partition differs from the previous bucket's.
                if last_lhs != bucket.lhs:
                    for e in lhs_partitioned_types:
                        model.clear_embeddings(e, Side.LHS)
                        embs = load_embeddings(e, bucket.lhs)
                        model.set_embeddings(e, embs, Side.LHS)
                if last_rhs != bucket.rhs:
                    for e in rhs_partitioned_types:
                        model.clear_embeddings(e, Side.RHS)
                        embs = load_embeddings(e, bucket.rhs)
                        model.set_embeddings(e, embs, Side.RHS)
                last_lhs, last_rhs = bucket.lhs, bucket.rhs

                edges = edge_reader.read(bucket.lhs, bucket.rhs)
                num_edges = len(edges)

                load_time = time.time() - tic
                tic = time.time()

                # Give each worker one slice of the bucket's edges and wait
                # for all of them.
                future_all_bucket_stats = pool.map_async(call, [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    )
                    for s in split_almost_equally(
                        num_edges, num_parts=num_workers)
                ])
                all_bucket_stats = get_async_result(
                    future_all_bucket_stats, pool)

                compute_time = time.time() - tic
                log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                    % (bucket, num_edges, compute_time,
                       num_edges / compute_time / 1e6, load_time))

                total_bucket_stats = Stats.sum(all_bucket_stats)
                all_edge_path_stats.append(total_bucket_stats)
                mean_bucket_stats = total_bucket_stats.average()
                log("Stats for edge path %d / %d, bucket %s: %s"
                    % (edge_path_idx + 1, num_edge_paths, bucket,
                       mean_bucket_stats))

                yield edge_path_idx, bucket, mean_bucket_stats

            total_edge_path_stats = Stats.sum(all_edge_path_stats)
            all_stats.append(total_edge_path_stats)
            mean_edge_path_stats = total_edge_path_stats.average()
            log("")
            log("Stats for edge path %d / %d: %s"
                % (edge_path_idx + 1, num_edge_paths, mean_edge_path_stats))
            log("")

            yield edge_path_idx, None, mean_edge_path_stats

        mean_stats = Stats.sum(all_stats).average()
        log("")
        log("Stats: %s" % mean_stats)
        log("")

        yield None, None, mean_stats

        completed = True
    finally:
        if completed:
            pool.close()
        else:
            # Error or early generator close: work may still be queued, so
            # waiting for a graceful drain could block. Stop the workers.
            # NOTE(review): assumes create_pool returns a multiprocessing-style
            # Pool (it is already used through map_async/close/join) - confirm.
            pool.terminate()
        pool.join()