def create_layer_of_buckets(
    nparts_lhs: int,
    nparts_rhs: int,
    layer_idx: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Build the ``layer_idx``-th layer of the #LHS x #RHS bucket matrix.

    A layer consists of the buckets (lhs, rhs) whose smaller coordinate is
    ``layer_idx``. The diagonal bucket forms a group on its own. Every other
    bucket is grouped with its transpose (when that exists too), so that the
    two end up next to each other in the output. Both the order of the groups
    and the order within each group are shuffled with ``generator``.
    """
    diagonal = Partition(layer_idx)
    groups: List[List[Bucket]] = [[Bucket(diagonal, diagonal)]]
    upper = max(nparts_lhs, nparts_rhs)
    for other_idx in range(layer_idx + 1, upper):
        other = Partition(other_idx)
        group: List[Bucket] = []
        if other < nparts_lhs:
            group.append(Bucket(other, diagonal))
        if other < nparts_rhs:
            group.append(Bucket(diagonal, other))
        generator.shuffle(group)
        groups.append(group)
    generator.shuffle(groups)
    return [bucket for group in groups for bucket in group]
def __init__(self, config: ConfigSchema, filter_paths: List[str]):
    """Build lookup tables of the known edges used to filter rankings.

    After construction, ``lhs_map[(lhs, rel)]`` lists every rhs id seen with
    that lhs id and relation, and ``rhs_map[(rhs, rel)]`` lists every lhs id
    seen with that rhs id and relation, across all ``filter_paths``.

    Args:
        config: the config. It must declare exactly one relation, with
            ``all_negs`` enabled, and exactly one entity type, which must be
            neither featurized nor partitioned.
        filter_paths: edge list paths whose edges are loaded into the maps.

    Raises:
        RuntimeError: if ``config`` violates any of the constraints above.
    """
    super().__init__()
    if len(config.relations) != 1 or len(config.entities) != 1:
        raise RuntimeError("Filtered ranking evaluation should only be used "
                           "with dynamic relations and one entity type.")
    if not config.relations[0].all_negs:
        raise RuntimeError("Filtered Eval can only be done with all negatives.")

    entity, = config.entities.values()
    if entity.featurized:
        raise RuntimeError("Entity cannot be featurized for filtered eval.")
    if entity.num_partitions > 1:
        raise RuntimeError("Entity cannot be partitioned for filtered eval.")

    self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for path in filter_paths:
        logger.info(f"Building links map from path {path}")
        e_reader = EDGELIST_READERS.make_instance(path)
        # Assume unpartitioned.
        edges = e_reader.read_edgelist(Partition(0), Partition(0))
        # Convert every column once, outside the loop. Calling to_tensor()
        # and indexing single tensor elements for each edge repeats that
        # work per edge. Assumes non-featurized entities and dynamic
        # relations (one relation id per edge).
        all_lhs = edges.lhs.to_tensor().tolist()
        all_rel = edges.rel.tolist()
        all_rhs = edges.rhs.to_tensor().tolist()
        for raw_lhs, raw_rel, raw_rhs in zip(all_lhs, all_rel, all_rhs):
            cur_lhs = int(raw_lhs)
            cur_rel = int(raw_rel)
            cur_rhs = int(raw_rhs)
            self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
            self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)
        logger.info(f"Done building links map from path {path}")
def __init__(self, config: ConfigSchema, filter_paths: List[str]):
    """Index the edges of ``filter_paths`` for filtered ranking evaluation.

    After construction, ``lhs_map[(lhs, rel)]`` lists every rhs id seen with
    that lhs id and relation, and ``rhs_map[(rhs, rel)]`` every lhs id seen
    with that rhs id and relation, across all the given paths.

    Raises:
        RuntimeError: unless the config has exactly one relation (with
            ``all_negs``) and exactly one entity type, that entity type being
            neither featurized nor partitioned.
    """
    if not (len(config.relations) == 1 and len(config.entities) == 1):
        raise RuntimeError("Filtered ranking evaluation should only be used "
                           "with dynamic relations and one entity type.")
    if not config.relations[0].all_negs:
        raise RuntimeError("Filtered Eval can only be done with all negatives.")

    (entity,) = config.entities.values()
    if entity.featurized:
        raise RuntimeError("Entity cannot be featurized for filtered eval.")
    if entity.num_partitions > 1:
        raise RuntimeError("Entity cannot be partitioned for filtered eval.")

    self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for path in filter_paths:
        log("Building links map from path %s" % path)
        # The entity is unpartitioned (checked above): only bucket (0, 0).
        lhs, rhs, rel = EdgeReader(path).read(Partition(0), Partition(0))
        for idx in range(lhs.size(0)):
            # Entities are non-featurized here (checked above).
            src = lhs[idx].collapse(is_featurized=False).item()
            rel_id = rel[idx].item()
            dst = rhs[idx].collapse(is_featurized=False).item()
            self.lhs_map[src, rel_id].append(dst)
            self.rhs_map[dst, rel_id].append(src)
        log("Done building links map from path %s" % path)
def create_buckets_ordered_lexicographically(
    nparts_lhs: int,
    nparts_rhs: int,
) -> List[Bucket]:
    """Return every bucket sorted by LHS partition, then by RHS partition.

    For instance, with two partitions on each side the order is
    (0, 0), (0, 1), (1, 0), (1, 1).
    """
    ordered: List[Bucket] = []
    for lhs_idx in range(nparts_lhs):
        for rhs_idx in range(nparts_rhs):
            ordered.append(Bucket(Partition(lhs_idx), Partition(rhs_idx)))
    return ordered
def write_new_version(
    self,
    config: ConfigSchema,
    entity_counts: Dict[EntityName, List[int]],
    embedding_storage_freelist: Dict[EntityName, Set[torch.FloatStorage]],
) -> None:
    """Save, as a new checkpoint version, the partitions held remotely.

    Does nothing beyond computing the metadata and version number when there
    is no partition client. Otherwise, this rank handles partitions
    ``rank, rank + num_machines, ...`` of every entity type. Each one is
    fetched from the partition server into a preallocated storage and then
    saved to ``self.storage``.
    """
    metadata = self.collect_metadata()
    new_version = self._version(True)
    client = self.partition_client
    if client is None:
        return
    for entity, econf in config.entities.items():
        owned_parts = range(self.rank, econf.num_partitions, self.num_machines)
        for part in owned_parts:
            logger.debug(f"Getting {entity} {part}")
            num_rows = entity_counts[entity][part]
            # Borrow, without removing it, any preallocated storage of this
            # entity type to receive the fetched embeddings.
            scratch = next(iter(embedding_storage_freelist[entity]))
            out = torch.FloatTensor(scratch).view(-1, config.dimension)[:num_rows]
            embs, serialized_optim_state = client.get(
                EntityName(entity), Partition(part), out=out)
            logger.debug(f"Done getting {entity} {part}")
            logger.debug(f"Saving {entity} {part} v{new_version}")
            self.storage.save_entity_partition(
                new_version, entity, part, embs, serialized_optim_state, metadata)
            logger.debug(f"Done saving {entity} {part} v{new_version}")
def new_pass(self, is_first: bool = False) -> None:
    """Start a new epoch of training.

    Clears the per-pass bookkeeping (``active``, ``done``, ``dirty``). If
    tree-ordered initialization is enabled and this is the first pass, only
    partition 0 starts out initialized. Otherwise ``initialized_partitions``
    is ``None``.
    """
    self.active, self.done, self.dirty = {}, set(), set()
    self.initialized_partitions = (
        {Partition(0)} if self.init_tree and is_first else None)
def write_new_version(self, config: ConfigSchema) -> None:
    """Save, as a new checkpoint version, the partitions held remotely.

    Syncs first when running in background mode. When a partition client is
    present, this rank fetches partitions ``rank, rank + num_machines, ...``
    of every entity type and writes each to an ``embeddings_*.h5`` file
    under ``self.path``, tagged with the new version number.
    """
    if self.background:
        self._sync()
    metadata = self.collect_metadata()
    new_version = self._version(True)
    client = self.partition_client
    if client is None:
        return
    for entity, econf in config.entities.items():
        for part in range(self.rank, econf.num_partitions, self.num_machines):
            vlog("Rank %d: getting %s %d" % (self.rank, entity, part))
            embs, optim_state = client.get(EntityName(entity), Partition(part))
            vlog("Rank %d: saving %s %d to disk" % (self.rank, entity, part))
            file_name = "embeddings_%s_%d.v%d.h5" % (entity, part, new_version)
            save_entity_partition(
                os.path.join(self.path, file_name), embs, optim_state, metadata)
def write_new_version(self, config: ConfigSchema) -> None:
    """Save, as a new checkpoint version, the partitions held remotely.

    Syncs first when running in background mode. When a partition client is
    present, this rank fetches partitions ``rank, rank + num_machines, ...``
    of every entity type and hands each one to ``self.storage`` under the
    new version number.
    """
    if self.background:
        self._sync()
    metadata = self.collect_metadata()
    new_version = self._version(True)
    client = self.partition_client
    if client is None:
        return
    for entity, econf in config.entities.items():
        owned_parts = range(self.rank, econf.num_partitions, self.num_machines)
        for part in owned_parts:
            logger.debug(f"Getting {entity} {part}")
            fetched = client.get(EntityName(entity), Partition(part))
            embs, serialized_optim_state = fetched
            logger.debug(f"Done getting {entity} {part}")
            logger.debug(f"Saving {entity} {part} v{new_version}")
            self.storage.save_entity_partition(
                new_version,
                entity,
                part,
                embs,
                serialized_optim_state,
                metadata,
            )
            logger.debug(f"Done saving {entity} {part} v{new_version}")
def write_new_version(self, config: ConfigSchema) -> None:
    """Save, as a new checkpoint version, the partitions held remotely.

    Syncs first when running in background mode. When a partition client is
    present, this rank fetches partitions ``rank, rank + num_machines, ...``
    of every entity type and writes each to
    ``<self.path>/embeddings_<entity>_<part>.v<version>.h5``.
    """
    if self.background:
        self._sync()
    metadata = self.collect_metadata()
    new_version = self._version(True)
    client = self.partition_client
    if client is None:
        return
    for entity, econf in config.entities.items():
        owned_parts = range(self.rank, econf.num_partitions, self.num_machines)
        for part in owned_parts:
            logger.debug(f"Getting {entity} {part}")
            embs, optim_state = client.get(EntityName(entity), Partition(part))
            logger.debug(f"Done getting {entity} {part}")
            target = os.path.join(
                self.path, f"embeddings_{entity}_{part}.v{new_version}.h5")
            logger.debug(f"Saving {entity} {part} to {target}")
            save_entity_partition(target, embs, optim_state, metadata)
            logger.debug(f"Done saving {entity} {part} to {target}")
def create_buckets_ordered_by_affinity(
    nparts_lhs: int,
    nparts_rhs: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Order all buckets so that neighbours tend to share partitions.

    The first bucket is picked at random. Each time a bucket is picked, its
    transpose follows immediately if it has not been output yet. The next
    pick is then drawn from the not-yet-output buckets containing the LHS or
    the RHS partition of the previously picked bucket. Only when none are
    left is it drawn from all remaining buckets. All randomness comes from
    ``generator``.

    Returns an empty list if either side has no partitions.
    """
    if nparts_lhs <= 0 or nparts_rhs <= 0:
        return []

    # Authoritative set of buckets still to be output (O(1) membership).
    pending: Set[Bucket] = set()
    # Randomly ordered stacks consumed from the end. A bucket sits in several
    # of them and they are never pruned eagerly, so a popped entry may be a
    # bucket that was already output, in which case it is skipped.
    global_stack: List[Bucket] = []
    stack_by_part: List[List[Bucket]] = [
        [] for _ in range(max(nparts_lhs, nparts_rhs))
    ]
    for lhs_idx in range(nparts_lhs):
        for rhs_idx in range(nparts_rhs):
            bucket = Bucket(Partition(lhs_idx), Partition(rhs_idx))
            pending.add(bucket)
            global_stack.append(bucket)
            stack_by_part[lhs_idx].append(bucket)
            stack_by_part[rhs_idx].append(bucket)
    generator.shuffle(global_stack)
    for stack in stack_by_part:
        generator.shuffle(stack)

    def pop_neighbour(anchor):
        # Pick one of the two stacks with probability proportional to its
        # size and pop its top. Accept it if it is still pending. Return
        # None once both stacks are drained.
        by_lhs = stack_by_part[anchor.lhs]
        by_rhs = stack_by_part[anchor.rhs]
        while by_lhs or by_rhs:
            (picked_stack,) = generator.choices(
                [by_lhs, by_rhs],
                weights=[len(by_lhs), len(by_rhs)],
            )
            candidate = picked_stack.pop()
            if candidate in pending:
                return candidate
        return None

    def pop_any():
        # Cannot run dry: every pending bucket is still in global_stack.
        while True:
            candidate = global_stack.pop()
            if candidate in pending:
                return candidate

    current = global_stack.pop()
    pending.remove(current)
    order = [current]
    while pending:
        mirrored = Bucket(current.rhs, current.lhs)
        if mirrored in pending:
            pending.remove(mirrored)
            order.append(mirrored)
            if not pending:
                break
        successor = pop_neighbour(current)
        if successor is None:
            successor = pop_any()
        pending.remove(successor)
        order.append(successor)
        current = successor
    return order
def create_buckets_ordered_by_affinity(
    nparts_lhs: int,
    nparts_rhs: int,
    *,
    generator: random.Random,
) -> List[Bucket]:
    """Order all buckets so that neighbours tend to share partitions.

    The first bucket is picked at random. Each time a bucket is picked, its
    transpose follows immediately if it has not been output yet. The next
    pick is then drawn from the not-yet-output buckets containing the LHS or
    the RHS partition of the previously picked bucket. Only when none are
    left is it drawn from all remaining buckets. All randomness comes from
    ``generator``.

    Returns an empty list if either side has no partitions.
    """
    if nparts_lhs <= 0 or nparts_rhs <= 0:
        return []

    # TODO Change this function to use the same cost model as the LockServer
    # when computing affinity (based on the number of entities to save and load)
    # rather than just the number of partitions in common. Pay attention to keep
    # the complexity of this algorithm linear in the number of buckets. This
    # comment is too short to give a full description, but the idea is that only
    # a few transitions are possible between a bucket and the next: the one that
    # preserves all (ent, part) pairs, the one that preserves only the lhs ones,
    # only the rhs ones, only the intersection of the two, or none at all. So we
    # can keep a dict from sets of (ent, part) to lists of buckets, and insert
    # each bucket into four of those lists, namely the ones for all its (ent,
    # part), its lhs ones, its rhs ones and the intersection of its lhs and rhs
    # ones. Then, when looking for the next bucket, we figure out the transition
    # that is cheapest (among the options defined above), determine the set of
    # (ent, part) we need to move to in order to achieve that transition type
    # and we look up in the dict to find a bucket containing those (ent, part).

    # Authoritative set of buckets still to be output (O(1) membership).
    pending: Set[Bucket] = set()
    # Randomly ordered stacks consumed from the end. A bucket sits in several
    # of them and they are never pruned eagerly, so a popped entry may be a
    # bucket that was already output, in which case it is skipped.
    global_stack: List[Bucket] = []
    stack_by_part: List[List[Bucket]] = [
        [] for _ in range(max(nparts_lhs, nparts_rhs))
    ]
    for lhs_idx in range(nparts_lhs):
        for rhs_idx in range(nparts_rhs):
            bucket = Bucket(Partition(lhs_idx), Partition(rhs_idx))
            pending.add(bucket)
            global_stack.append(bucket)
            stack_by_part[lhs_idx].append(bucket)
            stack_by_part[rhs_idx].append(bucket)
    generator.shuffle(global_stack)
    for stack in stack_by_part:
        generator.shuffle(stack)

    def pop_neighbour(anchor):
        # Pick one of the two stacks with probability proportional to its
        # size and pop its top. Accept it if it is still pending. Return
        # None once both stacks are drained.
        by_lhs = stack_by_part[anchor.lhs]
        by_rhs = stack_by_part[anchor.rhs]
        while by_lhs or by_rhs:
            (picked_stack,) = generator.choices(
                [by_lhs, by_rhs],
                weights=[len(by_lhs), len(by_rhs)],
            )
            candidate = picked_stack.pop()
            if candidate in pending:
                return candidate
        return None

    def pop_any():
        # Cannot run dry: every pending bucket is still in global_stack.
        while True:
            candidate = global_stack.pop()
            if candidate in pending:
                return candidate

    current = global_stack.pop()
    pending.remove(current)
    order = [current]
    while pending:
        mirrored = Bucket(current.rhs, current.lhs)
        if mirrored in pending:
            pending.remove(mirrored)
            order.append(mirrored)
            if not pending:
                break
        successor = pop_neighbour(current)
        if successor is None:
            successor = pop_any()
        pending.remove(successor)
        order.append(successor)
        current = successor
    return order
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
    embeddings.

    This is a generator. For each edge path it yields one
    ``(edge_path_idx, bucket, stats)`` tuple per bucket, then
    ``(edge_path_idx, None, stats)`` for that edge path as a whole. Finally
    it yields ``(None, None, stats)`` over all edge paths. Each ``stats`` is
    the ``average()`` of the corresponding summed ``Stats``.

    Args:
        config: describes the entities, edge paths and checkpoint to use.
        model: the model to evaluate. Built with ``make_model`` if None. The
            checkpointed model state (if any) and embeddings are loaded into
            it.
        evaluator: the batch processor computing the metrics. Defaults to
            ``RankingEvaluator()``.
    """
    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)
    # Everything after the pool exists runs inside this try. The pool is
    # therefore always shut down: closed normally on success, terminated if
    # an error occurs or the consumer abandons this generator (GeneratorExit).
    # Otherwise the worker processes would leak.
    try:
        if model is None:
            model = make_model(config)
        model.share_memory()

        state_dict, _ = checkpoint_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        model.eval()

        # Unpartitioned entities are loaded once and shared by both sides.
        for entity, econfig in config.entities.items():
            if econfig.num_partitions == 1:
                embs = load_embeddings(entity, Partition(0))
                model.set_embeddings(entity, embs, Side.LHS)
                model.set_embeddings(entity, embs, Side.RHS)

        all_stats: List[Stats] = []
        for edge_path_idx, edge_path in enumerate(config.edge_paths):
            log("Starting edge path %d / %d (%s)"
                % (edge_path_idx + 1, len(config.edge_paths), edge_path))
            edge_reader = EdgeReader(edge_path)

            all_edge_path_stats = []
            last_lhs, last_rhs = None, None
            for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
                tic = time.perf_counter()

                # Only reload a side's partitioned embeddings when its
                # partition changed since the previous bucket.
                if last_lhs != bucket.lhs:
                    for e in lhs_partitioned_types:
                        model.clear_embeddings(e, Side.LHS)
                        embs = load_embeddings(e, bucket.lhs)
                        model.set_embeddings(e, embs, Side.LHS)
                if last_rhs != bucket.rhs:
                    for e in rhs_partitioned_types:
                        model.clear_embeddings(e, Side.RHS)
                        embs = load_embeddings(e, bucket.rhs)
                        model.set_embeddings(e, embs, Side.RHS)
                last_lhs, last_rhs = bucket.lhs, bucket.rhs

                edges = edge_reader.read(bucket.lhs, bucket.rhs)
                num_edges = len(edges)

                load_time = time.perf_counter() - tic
                tic = time.perf_counter()
                # Split the bucket's edges evenly among the workers.
                all_bucket_stats = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    )
                    for s in split_almost_equally(num_edges, num_parts=num_workers)
                ])
                compute_time = time.perf_counter() - tic
                # Guard against a zero elapsed time on coarse clocks.
                if compute_time > 0:
                    throughput = num_edges / compute_time / 1e6
                else:
                    throughput = float("inf")

                log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                    % (bucket, num_edges, compute_time, throughput, load_time))

                total_bucket_stats = Stats.sum(all_bucket_stats)
                all_edge_path_stats.append(total_bucket_stats)
                mean_bucket_stats = total_bucket_stats.average()
                log("Stats for edge path %d / %d, bucket %s: %s"
                    % (edge_path_idx + 1, len(config.edge_paths), bucket,
                       mean_bucket_stats))

                yield edge_path_idx, bucket, mean_bucket_stats

            total_edge_path_stats = Stats.sum(all_edge_path_stats)
            all_stats.append(total_edge_path_stats)
            mean_edge_path_stats = total_edge_path_stats.average()
            log("")
            log("Stats for edge path %d / %d: %s"
                % (edge_path_idx + 1, len(config.edge_paths), mean_edge_path_stats))
            log("")

            yield edge_path_idx, None, mean_edge_path_stats

        mean_stats = Stats.sum(all_stats).average()
        log("")
        log("Stats: %s" % mean_stats)
        log("")

        yield None, None, mean_stats
    except BaseException:
        # Tasks may still be in flight (or lost with a dead worker), in which
        # case close()+join() could block forever: kill the workers instead.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
def _maybe_write_checkpoint(
    self,
    epoch_idx: int,
    edge_path_idx: int,
    edge_chunk_idx: int,
) -> None:
    """Write a new checkpoint version after an edge chunk has been trained.

    Rank 0 writes three things: the unpartitioned embeddings (with their
    optimizer state), the model and its optimizer state, and the stats
    collected by the bucket scheduler for this pass. Every rank then calls
    ``write_new_version``, waits on a barrier, switches to the new version,
    and waits again. Finally, old versions are preserved or removed
    according to ``config.checkpoint_preservation_interval``.
    """
    config = self.config

    # Preserving a checkpoint requires two steps:
    # - create a snapshot (w/ symlinks) after it's first written;
    # - don't delete it once the following one is written.
    # These two happen in two successive iterations of the main loop: the
    # one just before and the one just after the epoch boundary.
    preserve_old_checkpoint = should_preserve_old_checkpoint(
        self.iteration_manager, config.checkpoint_preservation_interval)
    preserve_new_checkpoint = should_preserve_old_checkpoint(
        self.iteration_manager + 1, config.checkpoint_preservation_interval)

    logger.info(
        f"Finished epoch {epoch_idx + 1} / {self.iteration_manager.num_epochs}, "
        f"edge path {edge_path_idx + 1} / {self.iteration_manager.num_edge_paths}, "
        f"edge chunk {edge_chunk_idx + 1} / {self.iteration_manager.num_edge_chunks}"
    )
    # Write metadata: for multiple machines, write from rank-0
    if self.rank == 0:
        # Iterate over the embeddings actually held, rather than every entity
        # type with a single partition in the config. A single-partition
        # type that was never loaded into the holder would otherwise raise a
        # KeyError here.
        for entity, embs in self.holder.unpartitioned_embeddings.items():
            optimizer = self.trainer.unpartitioned_optimizers[entity]
            self.checkpoint_manager.write(
                entity, Partition(0), embs.detach(),
                OptimizerStateDict(optimizer.state_dict()))

        logger.info("Writing the metadata")
        state_dict: ModuleStateDict = ModuleStateDict(self.model.state_dict())
        self.checkpoint_manager.write_model(
            state_dict,
            OptimizerStateDict(self.trainer.model_optimizer.state_dict()))

        logger.info("Writing the training stats")
        all_stats_dicts: List[Dict[str, object]] = []
        for stats in self.bucket_scheduler.get_stats_for_pass():
            stats_dict = {
                "epoch_idx": epoch_idx,
                "edge_path_idx": edge_path_idx,
                "edge_chunk_idx": edge_chunk_idx,
                "lhs_partition": stats.lhs_partition,
                "rhs_partition": stats.rhs_partition,
                "index": stats.index,
                "stats": stats.train.to_dict(),
            }
            if stats.eval_before is not None:
                stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
            if stats.eval_after is not None:
                stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
            all_stats_dicts.append(stats_dict)
        self.checkpoint_manager.append_stats(all_stats_dicts)

    logger.info("Writing the checkpoint")
    self.checkpoint_manager.write_new_version(
        config, self.entity_counts, self.embedding_storage_freelist)

    dist_logger.info(
        "Waiting for other workers to write their parts of the checkpoint")
    self._barrier()
    dist_logger.info("All parts of the checkpoint have been written")

    logger.info("Switching to the new checkpoint version")
    self.checkpoint_manager.switch_to_new_version()

    dist_logger.info(
        "Waiting for other workers to switch to the new checkpoint version"
    )
    self._barrier()
    dist_logger.info(
        "All workers have switched to the new checkpoint version")

    # After all the machines have finished committing
    # checkpoints, we either remove the old checkpoints
    # or we preserve it
    if preserve_new_checkpoint:
        # Add 1 so the index is a multiple of the interval, it looks nicer.
        self.checkpoint_manager.preserve_current_version(
            config, epoch_idx + 1)
    if not preserve_old_checkpoint:
        self.checkpoint_manager.remove_old_version(config)
def __init__(
    self,
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
):
    """Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.

    This constructor sets up everything that training needs:

    * the entity counts;
    * the preallocated shared embedding storages;
    * the lock/parameter/partition servers and clients, when more than one
      machine or any partition server is configured;
    * the worker pool and the checkpoint manager;
    * the model, trainer and evaluator (built from ``config`` when not
      given), with model state restored from the checkpoint or from
      ``config.init_path`` if available;
    * the unpartitioned embeddings.

    Args:
        config: the training config.
        model: optional prebuilt model; ``make_model(config)`` otherwise.
        trainer: optional batch processor used for training.
        evaluator: optional batch processor used for evaluation.
        rank: index of this trainer among ``config.num_machines``.
        subprocess_init: optional callable passed to every spawned process.

    Raises:
        RuntimeError: in distributed mode, if ``rank`` is out of range or
            torch.distributed is unavailable.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")

    self.config = config

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    # entity type -> number of entities in each of its partitions.
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(
                entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    holder = self.holder = EmbeddingHolder(config)

    logger.debug(
        f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
        f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
    )

    # We know ahead of time how many storages each entity type needs: one
    # per side on which it is partitioned, plus one if it is unpartitioned
    # on either side. We also know their max size (largest partition count
    # x D). We allocate these storages in advance in
    # `embedding_storage_freelist`. When we need storage for an entity type,
    # we pop it from this free list, and then add it back when we 'delete'
    # the embedding table.
    embedding_storage_freelist: Dict[
        EntityName, Set[torch.FloatStorage]] = defaultdict(set)
    for entity_type, counts in entity_counts.items():
        max_count = max(counts)
        num_sides = (
            (1 if entity_type in holder.lhs_partitioned_types else 0)
            + (1 if entity_type in holder.rhs_partitioned_types else 0)
            + (1 if entity_type in (holder.lhs_unpartitioned_types
                                    | holder.rhs_unpartitioned_types) else 0))
        for _ in range(num_sides):
            embedding_storage_freelist[entity_type].add(
                allocate_shared_tensor((max_count, config.dimension),
                                       dtype=torch.float).storage())

    # create the handlers, threads, etc. for distributed training
    if config.num_machines > 1 or config.num_partition_servers > 0:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError(
                "The installed PyTorch version doesn't provide "
                "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers)

        num_ps_groups = config.num_groups_for_partition_server
        # Group 0 is the trainers-only barrier group; the following
        # num_ps_groups groups all contain trainers and partition servers.
        groups: List[List[int]] = [ranks.trainers]  # barrier group
        groups += [ranks.trainers + ranks.partition_servers
                   ] * num_ps_groups  # ps groups
        group_idxs_for_partition_servers = range(1, len(groups))

        # Only rank 0 hosts the lock server; every rank gets a client.
        if rank == RANK_ZERO:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=holder.nparts_lhs,
                    nparts_rhs=holder.nparts_rhs,
                    entities_lhs=holder.lhs_partitioned_types,
                    entities_rhs=holder.rhs_partitioned_types,
                    entity_counts=entity_counts,
                    init_tree=config.distributed_tree_init_order,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=groups,
                subprocess_init=subprocess_init,
            )

        self.bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=groups,
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=groups,
            subprocess_init=subprocess_init,
        )

        # num_partition_servers == -1: each trainer machine also starts its
        # own partition server.
        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(
                    num_clients=len(ranks.trainers),
                    group_idxs=group_idxs_for_partition_servers,
                    log_stats=True,
                ),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=groups,
        )
        trainer_group, *groups_for_partition_servers = groups
        self.barrier_group = trainer_group

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(
                ranks.partition_servers,
                groups=groups_for_partition_servers,
                log_stats=True,
            )
        else:
            partition_client = None
    else:
        self.barrier_group = None
        self.bucket_scheduler = SingleMachineBucketScheduler(
            holder.nparts_lhs, holder.nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    self.num_workers = get_num_workers(config.workers)
    self.pool = create_pool(
        self.num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    self.checkpoint_manager = checkpoint_manager
    checkpoint_manager.register_metadata_provider(
        ConfigMetadataProvider(config))
    if rank == 0:
        checkpoint_manager.write_config(config)

    num_edge_chunks = get_num_edge_chunks(config)

    # Resume iteration from the version of the existing checkpoint.
    self.iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(self.iteration_manager)

    logger.info("Initializing global model...")
    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            model_optimizer=make_optimizer(config, model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )

    if config.init_path is not None:
        self.loadpath_manager = CheckpointManager(config.init_path)
    else:
        self.loadpath_manager = None

    # load model from checkpoint or loadpath, if available
    state_dict, optim_state = checkpoint_manager.maybe_read_model()
    if state_dict is None and self.loadpath_manager is not None:
        state_dict, optim_state = self.loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.model_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        count = entity_counts[entity][0]
        # Take one preallocated storage and view it as a (count, D) tensor
        # for the embeddings to be loaded into.
        s = embedding_storage_freelist[entity].pop()
        embs = torch.FloatTensor(s).view(-1, config.dimension)[:count]
        embs, optimizer = self._load_embeddings(entity, Partition(0), out=embs)
        holder.unpartitioned_embeddings[entity] = embs
        trainer.unpartitioned_optimizers[entity] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        # Track parameter identities so a tensor reachable under several
        # names is registered only once.
        shared_parameters: Set[int] = set()
        for name, param in model.named_parameters():
            if id(param) in shared_parameters:
                continue
            shared_parameters.add(id(param))
            key = f"model.{name}"
            logger.info(
                f"Adding {key} ({param.numel()} params) to parameter server"
            )
            parameter_sharer.set_param(key, param.data)
        for entity, embs in holder.unpartitioned_embeddings.items():
            key = f"entity.{entity}"
            logger.info(
                f"Adding {key} ({embs.numel()} params) to parameter server"
            )
            parameter_sharer.set_param(key, embs.data)

    # store everything in self
    self.model = model
    self.trainer = trainer
    self.evaluator = evaluator
    self.rank = rank
    self.entity_counts = entity_counts
    self.embedding_storage_freelist = embedding_storage_freelist

    self.strict = False
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
    embeddings.

    This is a generator. For each edge path it yields one
    ``(edge_path_idx, bucket, stats)`` tuple per bucket, then
    ``(edge_path_idx, None, stats)`` for that edge path as a whole. Finally
    it yields ``(None, None, stats)`` over all edge paths. Each ``stats`` is
    the ``average()`` of the corresponding summed ``Stats``.

    Args:
        config: describes the entities, edge paths and checkpoint to use.
        model: the model to evaluate. Built with ``make_model`` if None. The
            checkpointed model state (if any) and embeddings are loaded into
            it.
        evaluator: the batch processor computing the metrics. Defaults to
            ``RankingEvaluator()``.
        subprocess_init: optional callable run in each worker process.
    """
    tag_logs_with_process_name("Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name="EvalWorker",
        subprocess_init=subprocess_init,
    )
    # Everything after the pool exists runs inside this try. The pool is
    # therefore always shut down: closed normally on success, terminated if
    # an error occurs or the consumer abandons this generator (GeneratorExit).
    # Otherwise the worker processes would leak.
    try:
        if model is None:
            model = make_model(config)
        model.share_memory()

        state_dict, _ = checkpoint_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        model.eval()

        # Unpartitioned entities are loaded once and shared by both sides.
        for entity, econfig in config.entities.items():
            if econfig.num_partitions == 1:
                embs = load_embeddings(entity, Partition(0))
                model.set_embeddings(entity, embs, Side.LHS)
                model.set_embeddings(entity, embs, Side.RHS)

        all_stats: List[Stats] = []
        for edge_path_idx, edge_path in enumerate(config.edge_paths):
            logger.info(
                f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
                f"({edge_path})")
            edge_storage = EDGE_STORAGES.make_instance(edge_path)

            all_edge_path_stats = []
            last_lhs, last_rhs = None, None
            for bucket in create_buckets_ordered_lexicographically(
                    nparts_lhs, nparts_rhs):
                tic = time.perf_counter()

                # Only reload a side's partitioned embeddings when its
                # partition changed since the previous bucket.
                if last_lhs != bucket.lhs:
                    for e in lhs_partitioned_types:
                        model.clear_embeddings(e, Side.LHS)
                        embs = load_embeddings(e, bucket.lhs)
                        model.set_embeddings(e, embs, Side.LHS)
                if last_rhs != bucket.rhs:
                    for e in rhs_partitioned_types:
                        model.clear_embeddings(e, Side.RHS)
                        embs = load_embeddings(e, bucket.rhs)
                        model.set_embeddings(e, embs, Side.RHS)
                last_lhs, last_rhs = bucket.lhs, bucket.rhs

                edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
                num_edges = len(edges)

                load_time = time.perf_counter() - tic
                tic = time.perf_counter()
                # Split the bucket's edges evenly among the workers.
                future_all_bucket_stats = pool.map_async(call, [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    )
                    for s in split_almost_equally(num_edges, num_parts=num_workers)
                ])
                all_bucket_stats = \
                    get_async_result(future_all_bucket_stats, pool)
                compute_time = time.perf_counter() - tic
                # Guard against a zero elapsed time on coarse clocks.
                if compute_time > 0:
                    throughput = num_edges / compute_time / 1e6
                else:
                    throughput = float("inf")

                logger.info(
                    f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                    f"({throughput:.2g}M/sec); "
                    f"load time: {load_time:.2g} s")

                total_bucket_stats = Stats.sum(all_bucket_stats)
                all_edge_path_stats.append(total_bucket_stats)
                mean_bucket_stats = total_bucket_stats.average()
                logger.info(
                    f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                    f"bucket {bucket}: {mean_bucket_stats}")

                yield edge_path_idx, bucket, mean_bucket_stats

            total_edge_path_stats = Stats.sum(all_edge_path_stats)
            all_stats.append(total_edge_path_stats)
            mean_edge_path_stats = total_edge_path_stats.average()
            logger.info("")
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
                f"{mean_edge_path_stats}")
            logger.info("")

            yield edge_path_idx, None, mean_edge_path_stats

        mean_stats = Stats.sum(all_stats).average()
        logger.info("")
        logger.info(f"Stats: {mean_stats}")
        logger.info("")

        yield None, None, mean_stats
    except BaseException:
        # Tasks may still be in flight (or lost with a dead worker), in which
        # case close()+join() could block forever: kill the workers instead.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
    """Train the model bucket by bucket, yielding stats for each bucket.

    Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.

    Yields ``(index, eval_stats_before, stats, eval_stats_after)`` tuples:

    - on rank 0 only, first the entries already recorded in the checkpoint
      (``checkpoint_manager.maybe_read_stats()``);
    - then one tuple per bucket trained by this process. The same data is
      also appended to the checkpoint with ``append_stats``.

    The eval stats are None when ``config.eval_fraction`` holds out no edges
    for the bucket.

    Args:
        config: the training configuration.
        model: the model to train. Built with ``make_model(config)`` if None.
        trainer: batch processor used for training. Defaults to a ``Trainer``
            with an Adagrad optimizer over the model's parameters.
        evaluator: batch processor for the held-out edges. Defaults to a
            ``TrainingRankingEvaluator``.
        rank: index of this trainer. In distributed mode it must lie in
            ``[0, config.num_machines)``.
        subprocess_init: forwarded to every subprocess/pool started here.

    Raises:
        RuntimeError: in distributed mode, if ``rank`` is out of range or
            ``torch.distributed`` is not available.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    # entity_counts[entity][part] is the count read by load_count; it is used
    # to size freshly initialized embeddings in load_embeddings below.
    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(
                entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    logger.debug(f"nparts {nparts_lhs} {nparts_rhs} "
                 f"types {lhs_partitioned_types} {rhs_partitioned_types}")
    total_buckets = nparts_lhs * nparts_rhs

    # Set up either the distributed machinery (lock server, parameter
    # servers, partition servers, process group) or local stand-ins.
    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(config.num_machines,
                                                  config.num_partition_servers)

        # A single lock server, hosted by rank 0, hands out buckets.
        if rank == RANK_ZERO:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        # Every trainer hosts one parameter server and one client for it.
        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
            subprocess_init=subprocess_init,
        )

        # num_partition_servers == -1: each trainer also hosts a partition
        # server (a ParameterServer instance) in a subprocess.
        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers),
                                log_stats=True),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=[ranks.trainers],
                subprocess_init=subprocess_init,
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers,
                                               log_stats=True)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        # NOTE(review): presumably silences the distributed-only logger
        # (dist_logger) in single-machine mode -- verify.
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
        """Build a shared-memory optimizer.

        The optimizer is:
        - a DummyOptimizer when ``params`` is empty;
        - a RowAdagrad for embedding tables;
        - otherwise an Adagrad, using ``config.relation_lr`` when set and
          ``config.lr`` when not.
        """
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    checkpoint_manager.register_metadata_provider(
        ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    if config.num_edge_chunks is not None:
        num_edge_chunks = config.num_edge_chunks
    else:
        # Derived from the edge paths, the bucket grid and
        # config.max_edges_per_chunk.
        num_edge_chunks = get_num_edge_chunks(config.edge_paths, nparts_lhs,
                                              nparts_rhs, config.max_edges_per_chunk)
    # Resumes from the checkpoint's version: iterations already completed
    # there are, presumably, not repeated (TODO confirm in IterationManager).
    iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    # Optional secondary checkpoint used as a fallback source of initial
    # model state and embeddings.
    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        """Load one (entity, part) embedding table and its optimizer state.

        When ``strict`` is True, the current checkpoint is read with
        ``read``. Otherwise the sources are tried in this order:
        ``maybe_read`` on the current checkpoint, then on ``init_path``;
        failing both, the table is freshly initialized with ``init_embs``.
        """
        if strict:
            embs, optim_state = checkpoint_manager.read(
                entity, part, force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(
                entity, part, force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity, entity_counts[entity][part],
                                              config.dimension, config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    logger.info("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    # NOTE(review): presumably rounded so every eval batch can be split into
    # groups of eval_num_batch_negs -- confirm.
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size, config.eval_num_batch_negs)

    # Model (non-embedding) state: the current checkpoint takes precedence
    # over init_path.
    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    # Unpartitioned entities stay loaded for the whole run, shared by both
    # sides, each with a single optimizer keyed by (entity, Partition(0)).
    logger.debug("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    # Read through the closure by swap_partitioned_embeddings. It is set to
    # True at the end of the first iteration, once a full checkpoint exists.
    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        """Replace the partitioned embeddings of ``old_b`` with those of ``new_b``.

        Steps:
        - Partitions of ``old_b`` not reused by ``new_b`` are written to the
          checkpoint, their optimizers are dropped, and the bucket is released.
        - Partitions used by both buckets are kept in memory.
        - Partitions new to ``new_b`` are loaded.

        Either bucket may be None. Returns the number of embedding bytes
        written and read (optimizer state is not counted).
        """
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        logger.info(f"Swapping partitioned embeddings {old_b} {new_b}")

        types = ([(e, Side.LHS) for e in lhs_partitioned_types]
                 + [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            logger.info("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                side_name = side.pick("lhs", "rhs")
                logger.debug(f"Checkpointing ({entity} {part} {side_name})")
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(
                    trainer.entity_optimizers[optim_key].state_dict())
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                checkpoint_manager.write(entity, part, embs.detach(), optim_state)
                # NOTE(review): always true here -- the key was just indexed
                # above (a missing key would have raised KeyError).
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        bucket_logger = BucketLogger(logger, bucket=new_b)

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        bucket_logger.info("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                # Preserved tensors keep their existing optimizer, which was
                # not removed from trainer.entity_optimizers in step 1.
                bucket_logger.debug(f"Loading ({entity}, {part}) from preserved")
                embs, optim_state = tmp_emb[part_key], None
            else:
                bucket_logger.debug(f"Loading ({entity}, {part})")

                force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
                embs, optim_state = load_embeddings(entity, part, strict=strict,
                                                    force_dirty=force_dirty)
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state

            model.set_embeddings(entity, embs, side)
            # Record it so that, if the same (entity, part) shows up for the
            # other side, the same tensor is reused rather than loaded again.
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                bucket_logger.debug(f"Resetting optimizer {optim_key}")
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    bucket_logger.debug("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    # Rank 0 first re-emits the stats already stored in the checkpoint, so
    # consumers see the full history when a run is resumed.
    if rank == RANK_ZERO:
        for stats_dict in checkpoint_manager.maybe_read_stats():
            index: int = stats_dict["index"]
            stats: Stats = Stats.from_dict(stats_dict["stats"])
            eval_stats_before: Optional[Stats] = None
            if "eval_stats_before" in stats_dict:
                eval_stats_before = Stats.from_dict(
                    stats_dict["eval_stats_before"])
            eval_stats_after: Optional[Stats] = None
            if "eval_stats_after" in stats_dict:
                eval_stats_after = Stats.from_dict(
                    stats_dict["eval_stats_after"])
            yield (index, eval_stats_before, stats, eval_stats_after)

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
        logger.info(
            f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)
        logger.info(f"Edge path: {iteration_manager.edge_path}")

        # All trainers start the pass together.
        sync.barrier()
        dist_logger.info("Lock client new epoch...")
        bucket_scheduler.new_pass(
            is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            logger.info(f"still in queue: {remaining}")
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            bucket_logger = BucketLogger(logger, bucket=cur_b)

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            # Position of this bucket counted across all iterations. It is
            # the "index" in the stats and the background-I/O marker.
            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                bucket_logger.debug("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            bucket_logger.debug("Loading edges")
            edges = edge_storage.load_chunk_of_edges(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx,
                iteration_manager.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            bucket_logger.debug("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            # NOTE(review): this is only reproducible across runs if
            # Partition values hash like ints (int hashing is not randomized
            # by PYTHONHASHSEED) -- confirm.
            g = torch.Generator()
            g.manual_seed(
                hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            # The seeded permutation fixes which edges are held out for eval
            # (the last num_eval_edges). The training order is re-shuffled
            # with the global, unseeded RNG.
            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_before = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        )
                        for s in split_almost_equally(eval_edge_perm.size(0),
                                                      num_parts=num_workers)
                    ])
                all_eval_stats_before = \
                    get_async_result(future_all_eval_stats_before, pool)
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                bucket_logger.info(
                    f"Stats before training: {eval_stats_before}")

            # io_time includes the swap, the prefetch bookkeeping, the edge
            # loading and the pre-training evaluation.
            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            bucket_logger.debug("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            # NOTE(review): inside this comprehension `rank` is the worker
            # index from enumerate(). It shadows the trainer rank only there
            # (Python 3 comprehension scope), so during the first epoch every
            # worker except the first one is delayed by config.hogwild_delay.
            future_all_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay if epoch_idx == 0 and rank > 0 else 0,
                )
                for rank, s in enumerate(
                    split_almost_equally(edge_perm.size(0),
                                         num_parts=num_workers))
            ])
            all_stats = get_async_result(future_all_stats, pool)
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            # NOTE(review): raises ZeroDivisionError if compute_time or
            # io_time measures as 0 (possible with a coarse timer).
            bucket_logger.info(
                f"bucket {total_buckets - remaining} / {total_buckets} : "
                f"Processed {num_edges} edges in {compute_time:.2f} s "
                f"( {num_edges / compute_time / 1e6:.2g} M/sec ); "
                f"io: {io_time:.2f} s ( {io_bytes / io_time / 1e6:.2f} MB/sec )"
            )
            bucket_logger.info(f"{stats}")

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                bucket_logger.debug(
                    "Waiting for workers to perform evaluation")
                future_all_eval_stats_after = pool.map_async(
                    call, [
                        partial(
                            process_in_batches,
                            batch_size=eval_batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges,
                            indices=eval_edge_perm[s],
                        )
                        for s in split_almost_equally(eval_edge_perm.size(0),
                                                      num_parts=num_workers)
                    ])
                all_eval_stats_after = \
                    get_async_result(future_all_eval_stats_after, pool)
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                bucket_logger.info(f"Stats after training: {eval_stats_after}")

            # Add train/eval metrics to queue
            stats_dict = {
                "index": current_index,
                "stats": stats.to_dict(),
            }
            if eval_stats_before is not None:
                stats_dict["eval_stats_before"] = eval_stats_before.to_dict()
            if eval_stats_after is not None:
                stats_dict["eval_stats_after"] = eval_stats_after.to_dict()
            checkpoint_manager.append_stats(stats_dict)
            yield current_index, eval_stats_before, stats, eval_stats_after

        # Write back and release the last bucket this process held.
        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Preserving a checkpoint requires two steps:
        # - create a snapshot (w/ symlinks) after it's first written;
        # - don't delete it once the following one is written.
        # These two happen in two successive iterations of the main loop: the
        # one just before and the one just after the epoch boundary.
        preserve_old_checkpoint = should_preserve_old_checkpoint(
            iteration_manager, config.checkpoint_preservation_interval)
        preserve_new_checkpoint = should_preserve_old_checkpoint(
            iteration_manager + 1, config.checkpoint_preservation_interval)

        # Write metadata: for multiple machines, write from rank-0
        logger.info(
            f"Finished epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
            f"edge path {edge_path_idx + 1} / {iteration_manager.num_edge_paths}, "
            f"edge chunk {edge_chunk_idx + 1} / {iteration_manager.num_edge_chunks}"
        )
        # NOTE(review): compares against the literal 0, whereas the rest of
        # the function uses RANK_ZERO.
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity, Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0),
                        embs.detach(), OptimizerStateDict(optimizer.state_dict()))

            # The model checkpoint excludes entity embeddings; those are
            # stored separately with checkpoint_manager.write.
            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            logger.info("Writing the metadata")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        # Two-phase commit: every trainer writes its parts of the new
        # version, waits for all others, then switches to it.
        logger.info("Writing the checkpoint")
        checkpoint_manager.write_new_version(config)

        dist_logger.info(
            "Waiting for other workers to write their parts of the checkpoint")
        sync.barrier()
        dist_logger.info("All parts of the checkpoint have been written")

        logger.info("Switching to the new checkpoint version")
        checkpoint_manager.switch_to_new_version()

        dist_logger.info(
            "Waiting for other workers to switch to the new checkpoint version"
        )
        sync.barrier()
        dist_logger.info(
            "All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we either remove the old checkpoints
        # or we preserve it
        if preserve_new_checkpoint:
            # Add 1 so the index is a multiple of the interval, it looks nicer.
            checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
        if not preserve_old_checkpoint:
            checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    # NOTE(review): these cleanup steps only run if the consumer exhausts
    # the generator; stopping early leaves the pool and managers open.
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    logger.info("Exiting")
def train_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = RANK_ZERO,
) -> Generator[Tuple[int, Optional[Stats], Stats, Optional[Stats]], None, None]:
    """Train the model bucket by bucket, yielding stats for each bucket.

    Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.

    Yields one ``(index, eval_stats_before, stats, eval_stats_after)`` tuple
    per bucket trained by this process. The eval stats are None when
    ``config.eval_fraction`` holds out no edges for the bucket.

    Args:
        config: the training configuration.
        model: the model to train. Built with ``make_model(config)`` if None.
        trainer: batch processor used for training. Defaults to a ``Trainer``
            with an Adagrad optimizer over the model's parameters.
        evaluator: batch processor for the held-out edges. Defaults to a
            ``TrainingRankingEvaluator``.
        rank: index of this trainer. In distributed mode it must lie in
            ``[0, config.num_machines)``.

    Raises:
        RuntimeError: in distributed mode, if ``rank`` is out of range or
            ``torch.distributed`` is not available.

    NOTE(review): this chunk defines ``train_and_report_stats`` more than
    once. If they share a module, the last definition wins -- confirm which
    version is intended.
    """

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    # entity_counts[entity][part] is read from
    # <entity_path>/entity_count_<entity>_<part>.txt; it is used to size
    # freshly initialized embeddings in load_embeddings below.
    log("Loading entity counts...")
    if maybe_old_entity_path(config.entity_path):
        log("WARNING: It may be that your entity path contains files using the "
            "old format. See D14241362 for how to update them.")
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            with open(os.path.join(
                config.entity_path, "entity_count_%s_%d.txt" % (entity, part)
            ), "rt") as tf:
                entity_counts[entity].append(int(tf.read().strip()))

    # Figure out how many lhs and rhs partitions we need
    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)
    vlog("nparts %d %d types %s %s" %
         (nparts_lhs, nparts_rhs, lhs_partitioned_types, rhs_partitioned_types))
    total_buckets = nparts_lhs * nparts_rhs

    # Set up either the distributed machinery (lock server, parameter
    # servers, partition servers, process group) or local stand-ins.
    # dlog only logs in distributed mode.
    sync: AbstractSynchronizer
    bucket_scheduler: AbstractBucketScheduler
    parameter_sharer: Optional[ParameterSharer]
    partition_client: Optional[PartitionClient]
    if config.num_machines > 1:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError("The installed PyTorch version doesn't provide "
                               "distributed training capabilities.")
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers)

        # A single lock server, hosted by rank 0, hands out buckets.
        if rank == RANK_ZERO:
            log("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=nparts_lhs,
                    nparts_rhs=nparts_rhs,
                    lock_lhs=len(lhs_partitioned_types) > 0,
                    lock_rhs=len(rhs_partitioned_types) > 0,
                    init_tree=config.distributed_tree_init_order,
                ),
                server_rank=ranks.lock_server,
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server,
            client_rank=ranks.trainers[rank],
        )

        # Every trainer hosts one parameter server and one client for it.
        log("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            server_rank=ranks.parameter_servers[rank],
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        parameter_sharer = ParameterSharer(
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=[ranks.trainers],
        )

        # num_partition_servers == -1: each trainer also hosts a partition
        # server (a ParameterServer instance).
        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(num_clients=len(ranks.trainers)),
                server_rank=ranks.partition_servers[rank],
                world_size=ranks.world_size,
                init_method=config.distributed_init_method,
                groups=[ranks.trainers],
            )

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(ranks.partition_servers)
        else:
            partition_client = None

        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=[ranks.trainers],
        )
        trainer_group, = groups
        sync = DistributedSynchronizer(trainer_group)
        dlog = log

    else:
        sync = DummySynchronizer()
        bucket_scheduler = SingleMachineBucketScheduler(
            nparts_lhs, nparts_rhs, config.bucket_order)
        parameter_sharer = None
        partition_client = None
        dlog = lambda msg: None

    # fork early for HOGWILD threads
    log("Creating workers...")
    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    def make_optimizer(params: Iterable[torch.nn.Parameter], is_emb: bool) -> Optimizer:
        """Build a shared-memory optimizer.

        The optimizer is:
        - a DummyOptimizer when ``params`` is empty;
        - a RowAdagrad for embedding tables;
        - otherwise an Adagrad, using ``config.relation_lr`` when set and
          ``config.lr`` when not.
        """
        params = list(params)
        if len(params) == 0:
            optimizer = DummyOptimizer()
        elif is_emb:
            optimizer = RowAdagrad(params, lr=config.lr)
        else:
            if config.relation_lr is not None:
                lr = config.relation_lr
            else:
                lr = config.lr
            optimizer = Adagrad(params, lr=lr)
        optimizer.share_memory()
        return optimizer

    # background_io is only supported in single-machine mode
    background_io = config.background_io and config.num_machines == 1

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        background=background_io,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
    )
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    checkpoint_manager.write_config(config)

    # Resumes from the checkpoint's version: iterations already completed
    # there are, presumably, skipped by remaining_iterations() (TODO confirm).
    iteration_manager = IterationManager(
        config.num_epochs, config.edge_paths, config.num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version)
    checkpoint_manager.register_metadata_provider(iteration_manager)

    # Optional secondary checkpoint used as a fallback source of initial
    # model state and embeddings.
    if config.init_path is not None:
        loadpath_manager = CheckpointManager(config.init_path)
    else:
        loadpath_manager = None

    def load_embeddings(
        entity: EntityName,
        part: Partition,
        strict: bool = False,
        force_dirty: bool = False,
    ) -> Tuple[torch.nn.Parameter, Optional[OptimizerStateDict]]:
        """Load one (entity, part) embedding table and its optimizer state.

        When ``strict`` is True, the current checkpoint is read with
        ``read``. Otherwise the sources are tried in this order:
        ``maybe_read`` on the current checkpoint, then on ``init_path``;
        failing both, the table is freshly initialized with ``init_embs``.
        """
        if strict:
            embs, optim_state = checkpoint_manager.read(entity, part,
                                                        force_dirty=force_dirty)
        else:
            # Strict is only false during the first iteration, because in that
            # case the checkpoint may not contain any data (unless a previous
            # run was resumed) so we fall back on initial values.
            embs, optim_state = checkpoint_manager.maybe_read(entity, part,
                                                              force_dirty=force_dirty)
            if embs is None and loadpath_manager is not None:
                embs, optim_state = loadpath_manager.maybe_read(entity, part)
            if embs is None:
                embs, optim_state = init_embs(entity, entity_counts[entity][part],
                                              config.dimension, config.init_scale)
        assert embs.is_shared()
        return torch.nn.Parameter(embs), optim_state

    log("Initializing global model...")

    if model is None:
        model = make_model(config)
    model.share_memory()
    if trainer is None:
        trainer = Trainer(
            global_optimizer=make_optimizer(model.parameters(), False),
            loss_fn=config.loss_fn,
            margin=config.margin,
            relations=config.relations,
        )
    if evaluator is None:
        evaluator = TrainingRankingEvaluator(
            override_num_batch_negs=config.eval_num_batch_negs,
            override_num_uniform_negs=config.eval_num_uniform_negs,
        )
    # NOTE(review): presumably rounded so every eval batch can be split into
    # groups of eval_num_batch_negs -- confirm.
    eval_batch_size = round_up_to_nearest_multiple(config.batch_size, config.eval_num_batch_negs)

    # Model (non-embedding) state: the current checkpoint takes precedence
    # over init_path.
    state_dict, optim_state = checkpoint_manager.maybe_read_model()

    if state_dict is None and loadpath_manager is not None:
        state_dict, optim_state = loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.global_optimizer.load_state_dict(optim_state)

    # Unpartitioned entities stay loaded for the whole run, shared by both
    # sides, each with a single optimizer keyed by (entity, Partition(0)).
    vlog("Loading unpartitioned entities...")
    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs, optim_state = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)
            optimizer = make_optimizer([embs], True)
            if optim_state is not None:
                optimizer.load_state_dict(optim_state)
            trainer.entity_optimizers[(entity, Partition(0))] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        parameter_sharer.share_model_params(model)

    # Read through the closure by swap_partitioned_embeddings. It is set to
    # True at the end of the first iteration, once a full checkpoint exists.
    strict = False

    def swap_partitioned_embeddings(
        old_b: Optional[Bucket],
        new_b: Optional[Bucket],
    ):
        """Replace the partitioned embeddings of ``old_b`` with those of ``new_b``.

        Steps:
        - Partitions of ``old_b`` not reused by ``new_b`` are written to the
          checkpoint, their optimizers are dropped, and the bucket is released.
        - Partitions used by both buckets are kept in memory.
        - Partitions new to ``new_b`` are loaded.

        Either bucket may be None. Returns the number of embedding bytes
        written and read (optimizer state is not counted).
        """
        # 0. given the old and new buckets, construct data structures to keep
        #    track of old and new embedding (entity, part) tuples

        io_bytes = 0
        log("Swapping partitioned embeddings %s %s" % (old_b, new_b))

        types = ([(e, Side.LHS) for e in lhs_partitioned_types]
                 + [(e, Side.RHS) for e in rhs_partitioned_types])
        old_parts = {(e, old_b.get_partition(side)): side
                     for e, side in types if old_b is not None}
        new_parts = {(e, new_b.get_partition(side)): side
                     for e, side in types if new_b is not None}

        to_checkpoint = set(old_parts) - set(new_parts)
        preserved = set(old_parts) & set(new_parts)

        # 1. checkpoint embeddings that will not be used in the next pair
        #
        if old_b is not None:  # there are previous embeddings to checkpoint
            log("Writing partitioned embeddings")
            for entity, part in to_checkpoint:
                side = old_parts[(entity, part)]
                vlog("Checkpointing (%s %d %s)" %
                     (entity, part, side.pick("lhs", "rhs")))
                embs = model.get_embeddings(entity, side)
                optim_key = (entity, part)
                optim_state = OptimizerStateDict(trainer.entity_optimizers[optim_key].state_dict())
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state
                checkpoint_manager.write(entity, part, embs.detach(), optim_state)
                # NOTE(review): always true here -- the key was just indexed
                # above (a missing key would have raised KeyError).
                if optim_key in trainer.entity_optimizers:
                    del trainer.entity_optimizers[optim_key]
                # these variables are holding large objects; let them be freed
                del embs
                del optim_state

            bucket_scheduler.release_bucket(old_b)

        # 2. copy old embeddings that will be used in the next pair
        #    into a temporary dictionary
        #
        tmp_emb = {x: model.get_embeddings(x[0], old_parts[x]) for x in preserved}

        for entity, _ in types:
            model.clear_embeddings(entity, Side.LHS)
            model.clear_embeddings(entity, Side.RHS)

        if new_b is None:  # there are no new embeddings to load
            return io_bytes

        # 3. load new embeddings into the model/optimizer, either from disk
        #    or the temporary dictionary
        #
        log("Loading entities")
        for entity, side in types:
            part = new_b.get_partition(side)
            part_key = (entity, part)
            if part_key in tmp_emb:
                # Preserved tensors keep their existing optimizer, which was
                # not removed from trainer.entity_optimizers in step 1.
                vlog("Loading (%s, %d) from preserved" % (entity, part))
                embs, optim_state = tmp_emb[part_key], None
            else:
                vlog("Loading (%s, %d)" % (entity, part))

                force_dirty = bucket_scheduler.check_and_set_dirty(entity, part)
                embs, optim_state = load_embeddings(
                    entity, part, strict=strict, force_dirty=force_dirty)
                io_bytes += embs.numel() * embs.element_size()  # ignore optim state

            model.set_embeddings(entity, embs, side)
            # Record it so that, if the same (entity, part) shows up for the
            # other side, the same tensor is reused rather than loaded again.
            tmp_emb[part_key] = embs

            optim_key = (entity, part)
            if optim_key not in trainer.entity_optimizers:
                vlog("Resetting optimizer %s" % (optim_key,))
                optimizer = make_optimizer([embs], True)
                if optim_state is not None:
                    vlog("Setting optim state")
                    optimizer.load_state_dict(optim_state)

                trainer.entity_optimizers[optim_key] = optimizer

        return io_bytes

    # Start of the main training loop.
    for epoch_idx, edge_path_idx, edge_chunk_idx \
            in iteration_manager.remaining_iterations():
        log("Starting epoch %d / %d edge path %d / %d edge chunk %d / %d" %
            (epoch_idx + 1, iteration_manager.num_epochs,
             edge_path_idx + 1, iteration_manager.num_edge_paths,
             edge_chunk_idx + 1, iteration_manager.num_edge_chunks))
        edge_reader = EdgeReader(iteration_manager.edge_path)
        log("edge_path= %s" % iteration_manager.edge_path)

        # All trainers start the pass together.
        sync.barrier()
        dlog("Lock client new epoch...")
        bucket_scheduler.new_pass(is_first=iteration_manager.iteration_idx == 0)
        sync.barrier()

        remaining = total_buckets
        cur_b = None
        while remaining > 0:
            old_b = cur_b
            io_time = 0.
            io_bytes = 0
            cur_b, remaining = bucket_scheduler.acquire_bucket()
            # NOTE(review): unlike the other messages, this bypasses log()
            # and prints straight to stderr.
            print('still in queue: %d' % remaining, file=sys.stderr)
            if cur_b is None:
                if old_b is not None:
                    # if you couldn't get a new pair, release the lock
                    # to prevent a deadlock!
                    tic = time.time()
                    io_bytes += swap_partitioned_embeddings(old_b, None)
                    io_time += time.time() - tic
                time.sleep(1)  # don't hammer td
                continue

            # Prefix messages with the current bucket; only `always=True`
            # messages go through log(), the others through vlog().
            def log_status(msg, always=False):
                f = log if always else vlog
                f("%s: %s" % (cur_b, msg))

            tic = time.time()

            io_bytes += swap_partitioned_embeddings(old_b, cur_b)

            # Position of this bucket counted across all iterations; also
            # used as the background-I/O marker.
            current_index = \
                (iteration_manager.iteration_idx + 1) * total_buckets - remaining

            next_b = bucket_scheduler.peek()
            if next_b is not None and background_io:
                # Ensure the previous bucket finished writing to disk.
                checkpoint_manager.wait_for_marker(current_index - 1)

                log_status("Prefetching")
                for entity in lhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.lhs)
                for entity in rhs_partitioned_types:
                    checkpoint_manager.prefetch(entity, next_b.rhs)

                checkpoint_manager.record_marker(current_index)

            log_status("Loading edges")
            edges = edge_reader.read(
                cur_b.lhs, cur_b.rhs, edge_chunk_idx, config.num_edge_chunks)
            num_edges = len(edges)
            # this might be off in the case of tensorlist or extra edge fields
            io_bytes += edges.lhs.tensor.numel() * edges.lhs.tensor.element_size()
            io_bytes += edges.rhs.tensor.numel() * edges.rhs.tensor.element_size()
            io_bytes += edges.rel.numel() * edges.rel.element_size()

            log_status("Shuffling edges")
            # Fix a seed to get the same permutation every time; have it
            # depend on all and only what affects the set of edges.
            # NOTE(review): this is only reproducible across runs if
            # Partition values hash like ints -- confirm.
            g = torch.Generator()
            g.manual_seed(hash((edge_path_idx, edge_chunk_idx, cur_b.lhs, cur_b.rhs)))

            # The seeded permutation fixes which edges are held out for eval
            # (the last num_eval_edges). The training order is re-shuffled
            # with the global, unseeded RNG.
            num_eval_edges = int(num_edges * config.eval_fraction)
            if num_eval_edges > 0:
                edge_perm = torch.randperm(num_edges, generator=g)
                eval_edge_perm = edge_perm[-num_eval_edges:]
                num_edges -= num_eval_edges
                edge_perm = edge_perm[torch.randperm(num_edges)]
            else:
                edge_perm = torch.randperm(num_edges)

            # HOGWILD evaluation before training
            eval_stats_before: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_before = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_before = Stats.sum(all_eval_stats_before).average()
                log("stats before %s: %s" % (cur_b, eval_stats_before))

            # io_time includes the swap, the prefetch bookkeeping, the edge
            # loading and the pre-training evaluation.
            io_time += time.time() - tic
            tic = time.time()
            # HOGWILD training
            log_status("Waiting for workers to perform training")
            # FIXME should we only delay if iteration_idx == 0?
            # NOTE(review): inside this comprehension `rank` is the worker
            # index from enumerate(). It shadows the trainer rank only there
            # (Python 3 comprehension scope), so during the first epoch every
            # worker except the first one is delayed by config.hogwild_delay.
            all_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=trainer,
                    edges=edges,
                    indices=edge_perm[s],
                    delay=config.hogwild_delay if epoch_idx == 0 and rank > 0 else 0,
                )
                for rank, s in enumerate(split_almost_equally(edge_perm.size(0),
                                                              num_parts=num_workers))
            ])
            stats = Stats.sum(all_stats).average()
            compute_time = time.time() - tic

            # NOTE(review): raises ZeroDivisionError if compute_time or
            # io_time measures as 0 (possible with a coarse timer).
            log_status(
                "bucket %d / %d : Processed %d edges in %.2f s "
                "( %.2g M/sec ); io: %.2f s ( %.2f MB/sec )" %
                (total_buckets - remaining, total_buckets,
                 num_edges, compute_time, num_edges / compute_time / 1e6,
                 io_time, io_bytes / io_time / 1e6),
                always=True)
            log_status("%s" % stats, always=True)

            # HOGWILD eval after training
            eval_stats_after: Optional[Stats] = None
            if num_eval_edges > 0:
                log_status("Waiting for workers to perform evaluation")
                all_eval_stats_after = pool.map(call, [
                    partial(
                        process_in_batches,
                        batch_size=eval_batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges,
                        indices=eval_edge_perm[s],
                    )
                    for s in split_almost_equally(eval_edge_perm.size(0),
                                                  num_parts=num_workers)
                ])
                eval_stats_after = Stats.sum(all_eval_stats_after).average()
                log("stats after %s: %s" % (cur_b, eval_stats_after))

            # Add train/eval metrics to queue
            yield current_index, eval_stats_before, stats, eval_stats_after

        # Write back and release the last bucket this process held.
        swap_partitioned_embeddings(cur_b, None)

        # Distributed Processing: all machines can leave the barrier now.
        sync.barrier()

        # Write metadata: for multiple machines, write from rank-0
        log("Finished epoch %d path %d pass %d; checkpointing global state."
            % (epoch_idx + 1, edge_path_idx + 1, edge_chunk_idx + 1))
        log("My rank: %d" % rank)
        # NOTE(review): compares against the literal 0, whereas the rest of
        # the function uses RANK_ZERO.
        if rank == 0:
            for entity, econfig in config.entities.items():
                if econfig.num_partitions == 1:
                    embs = model.get_embeddings(entity, Side.LHS)
                    optimizer = trainer.entity_optimizers[(entity, Partition(0))]

                    checkpoint_manager.write(
                        entity, Partition(0),
                        embs.detach(), OptimizerStateDict(optimizer.state_dict()))

            # The model checkpoint excludes entity embeddings; those are
            # stored separately with checkpoint_manager.write.
            sanitized_state_dict: ModuleStateDict = {}
            for k, v in ModuleStateDict(model.state_dict()).items():
                if k.startswith('lhs_embs') or k.startswith('rhs_embs'):
                    # skipping state that's an entity embedding
                    continue
                sanitized_state_dict[k] = v

            log("Writing metadata...")
            checkpoint_manager.write_model(
                sanitized_state_dict,
                OptimizerStateDict(trainer.global_optimizer.state_dict()),
            )

        # Two-phase commit: every trainer writes its parts of the new
        # version, waits for all others, then switches to it.
        log("Writing the checkpoint...")
        checkpoint_manager.write_new_version(config)

        dlog("Waiting for other workers to write their parts of the checkpoint: rank %d" % rank)
        sync.barrier()
        dlog("All parts of the checkpoint have been written")

        log("Switching to new checkpoint version...")
        checkpoint_manager.switch_to_new_version()

        dlog("Waiting for other workers to switch to the new checkpoint version: rank %d" % rank)
        sync.barrier()
        dlog("All workers have switched to the new checkpoint version")

        # After all the machines have finished committing
        # checkpoints, we remove the old checkpoints.
        checkpoint_manager.remove_old_version(config)

        # now we're sure that all partition files exist,
        # so be strict about loading them
        strict = True

    # quiescence
    # NOTE(review): these cleanup steps only run if the consumer exhausts
    # the generator; stopping early leaves the pool and managers open.
    pool.close()
    pool.join()

    sync.barrier()

    checkpoint_manager.close()
    if loadpath_manager is not None:
        loadpath_manager.close()

    # FIXME join distributed workers (not really necessary)

    log("Exiting")