def __init__(
    self,
    global_optimizer: Optimizer,
    loss_fn: str,
    margin: float,
    relations: List[RelationSchema],
) -> None:
    """Store the optimizer(s), the instantiated loss and the relation configs.

    Args:
        global_optimizer: stored unchanged as ``self.global_optimizer``.
        loss_fn: name resolved through ``LOSS_FUNCTIONS.get_class``; the
            resulting class is instantiated and stored as ``self.loss_fn``.
        margin: passed positionally to the loss constructor only when
            ``loss_fn == "ranking"``; ignored for every other loss.
        relations: stored unchanged as ``self.relations``.
    """
    super().__init__()
    self.global_optimizer = global_optimizer
    # Starts empty; keyed by (entity type, partition) per the annotation.
    self.entity_optimizers: Dict[Tuple[EntityName, Partition], Optimizer] = {}

    loss_cls = LOSS_FUNCTIONS.get_class(loss_fn)
    # TODO: only the ranking loss takes a margin, so the constructor arguments
    # are chosen by loss name here. A uniform loss constructor would remove
    # this special case.
    ctor_args = (margin,) if loss_fn == "ranking" else ()
    self.loss_fn = loss_cls(*ctor_args)

    self.relations = relations
def __init__(self, config: ConfigSchema, filter_paths: List[str]) -> None:
    """Build a ranking evaluator that knows which edges exist in ``filter_paths``.

    After construction, for every edge ``(lhs, rel, rhs)`` found in the edge
    sets at ``filter_paths``:
      * ``self.lhs_map[lhs, rel]`` contains ``rhs``
      * ``self.rhs_map[rhs, rel]`` contains ``lhs``

    Args:
        config: training config. It must have exactly one relation (with
            ``all_negs`` set) and exactly one entity type, which must be
            neither featurized nor partitioned.
        filter_paths: edge paths whose (unpartitioned) edges populate the
            maps above.

    Raises:
        RuntimeError: if the config does not satisfy the constraints above.
    """
    loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
    relation_weights = [r.weight for r in config.relations]
    super().__init__(loss_fn, relation_weights)

    if len(config.relations) != 1 or len(config.entities) != 1:
        raise RuntimeError(
            "Filtered ranking evaluation should only be used "
            "with dynamic relations and one entity type."
        )
    if not config.relations[0].all_negs:
        raise RuntimeError("Filtered Eval can only be done with all negatives.")

    (entity,) = config.entities.values()
    if entity.featurized:
        raise RuntimeError("Entity cannot be featurized for filtered eval.")
    if entity.num_partitions > 1:
        raise RuntimeError("Entity cannot be partitioned for filtered eval.")

    self.lhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    self.rhs_map: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for path in filter_paths:
        logger.info(f"Building links map from path {path}")
        e_storage = EDGE_STORAGES.make_instance(path)
        # Assume unpartitioned.
        edges = e_storage.load_edges(UNPARTITIONED, UNPARTITIONED)
        # Convert each column to plain Python ints once. The per-edge loop then
        # no longer re-fetches the tensors and indexes them element by element.
        # Assume non-featurized entities (to_tensor) and dynamic relations
        # (rel holds one relation id per edge).
        lhs_ids = edges.lhs.to_tensor().tolist()
        rel_ids = edges.rel.tolist()
        rhs_ids = edges.rhs.to_tensor().tolist()
        for cur_lhs, cur_rel, cur_rhs in zip(lhs_ids, rel_ids, rhs_ids):
            self.lhs_map[cur_lhs, cur_rel].append(cur_rhs)
            self.rhs_map[cur_rhs, cur_rel].append(cur_lhs)
        logger.info(f"Done building links map from path {path}")
def __init__(  # noqa
    self,
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    trainer: Optional[AbstractBatchProcessor] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    rank: Rank = SINGLE_TRAINER,
    subprocess_init: Optional[Callable[[], None]] = None,
    stats_handler: StatsHandler = NOOP_STATS_HANDLER,
):
    """Each epoch/pass, for each partition pair, loads in embeddings and
    edgelist from disk, runs HOGWILD training on them, and writes partitions
    back to disk.

    This constructor does the one-time setup:
      * reads per-entity, per-partition counts from ``config.entity_path``;
      * preallocates shared embedding storage for each entity type;
      * in distributed mode (``num_machines > 1`` or
        ``num_partition_servers > 0``), starts the lock/parameter/partition
        servers and joins the process groups; otherwise uses a
        single-machine bucket scheduler;
      * creates the worker pool and the checkpoint manager;
      * builds (or adopts) the model, trainer and evaluator, then restores
        model/optimizer state from the checkpoint or ``config.init_path``;
      * loads the unpartitioned embeddings and, if a parameter sharer
        exists, registers model parameters and those embeddings with it.

    Args:
        config: full training configuration.
        model: model to train; built with ``make_model(config)`` if None.
        trainer: batch processor used for training; a ``Trainer`` is built
            if None.
        evaluator: batch processor used for evaluation; a
            ``RankingEvaluator`` (with the ``eval_num_*_negs`` overrides) is
            built if None.
        rank: this trainer's index. In distributed mode it must satisfy
            ``0 <= rank < config.num_machines``.
        subprocess_init: forwarded to every server, the worker pool and the
            checkpoint manager's background process.
        stats_handler: forwarded to the bucket scheduler / lock server and
            stored on ``self``.

    Raises:
        RuntimeError: in distributed mode, if ``rank`` is out of range or the
            installed PyTorch lacks ``torch.distributed`` support.
    """
    tag_logs_with_process_name(f"Trainer-{rank}")
    self.config = config

    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    logger.info("Loading entity counts...")
    entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
    # entity name -> list of entity counts, one per partition.
    entity_counts: Dict[str, List[int]] = {}
    for entity, econf in config.entities.items():
        entity_counts[entity] = []
        for part in range(econf.num_partitions):
            entity_counts[entity].append(entity_storage.load_count(entity, part))

    # Figure out how many lhs and rhs partitions we need
    holder = self.holder = EmbeddingHolder(config)

    logger.debug(
        f"nparts {holder.nparts_lhs} {holder.nparts_rhs} "
        f"types {holder.lhs_partitioned_types} {holder.rhs_partitioned_types}"
    )

    # We know ahead of time how many storages each entity type needs (one per
    # side on which it is partitioned, plus one if it is unpartitioned), as
    # well as the max size of each storage (max partition count x D).
    # We allocate these storages in advance in `embedding_storage_freelist`.
    # When we need storage for an entity type, we pop it from this free list,
    # and then add it back when we 'delete' the embedding table.
    embedding_storage_freelist: Dict[
        EntityName, Set[torch.FloatStorage]
    ] = defaultdict(set)
    for entity_type, counts in entity_counts.items():
        # Size every storage for the largest partition so any partition fits.
        max_count = max(counts)
        num_sides = (
            (1 if entity_type in holder.lhs_partitioned_types else 0)
            + (1 if entity_type in holder.rhs_partitioned_types else 0)
            + (
                1
                if entity_type
                in (holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types)
                else 0
            )
        )
        for _ in range(num_sides):
            embedding_storage_freelist[entity_type].add(
                allocate_shared_tensor(
                    (max_count, config.entity_dimension(entity_type)),
                    dtype=torch.float,
                ).storage()
            )

    # create the handlers, threads, etc. for distributed training
    if config.num_machines > 1 or config.num_partition_servers > 0:
        if not 0 <= rank < config.num_machines:
            raise RuntimeError("Invalid rank for trainer")
        if not td.is_available():
            raise RuntimeError(
                "The installed PyTorch version doesn't provide "
                "distributed training capabilities."
            )
        ranks = ProcessRanks.from_num_invocations(
            config.num_machines, config.num_partition_servers
        )

        # Group 0 is the trainers-only barrier group. Groups 1..N each contain
        # trainers + partition servers and are used for partition traffic.
        num_ps_groups = config.num_groups_for_partition_server
        groups: List[List[int]] = [ranks.trainers]  # barrier group
        groups += [
            ranks.trainers + ranks.partition_servers
        ] * num_ps_groups  # ps groups
        group_idxs_for_partition_servers = range(1, len(groups))

        # Only one trainer (SINGLE_TRAINER) starts the shared lock server.
        if rank == SINGLE_TRAINER:
            logger.info("Setup lock server...")
            start_server(
                LockServer(
                    num_clients=len(ranks.trainers),
                    nparts_lhs=holder.nparts_lhs,
                    nparts_rhs=holder.nparts_rhs,
                    entities_lhs=holder.lhs_partitioned_types,
                    entities_rhs=holder.rhs_partitioned_types,
                    entity_counts=entity_counts,
                    init_tree=config.distributed_tree_init_order,
                    stats_handler=stats_handler,
                ),
                process_name="LockServer",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.lock_server,
                groups=groups,
                subprocess_init=subprocess_init,
            )

        self.bucket_scheduler = DistributedBucketScheduler(
            server_rank=ranks.lock_server, client_rank=ranks.trainers[rank]
        )

        # Every trainer starts its own parameter server and a client for it.
        logger.info("Setup param server...")
        start_server(
            ParameterServer(num_clients=len(ranks.trainers)),
            process_name=f"ParamS-{rank}",
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            server_rank=ranks.parameter_servers[rank],
            groups=groups,
            subprocess_init=subprocess_init,
        )

        parameter_sharer = ParameterSharer(
            process_name=f"ParamC-{rank}",
            client_rank=ranks.parameter_clients[rank],
            all_server_ranks=ranks.parameter_servers,
            init_method=config.distributed_init_method,
            world_size=ranks.world_size,
            groups=groups,
            subprocess_init=subprocess_init,
        )

        # num_partition_servers == -1: each trainer also hosts a partition
        # server (a ParameterServer bound to the partition-server groups).
        if config.num_partition_servers == -1:
            start_server(
                ParameterServer(
                    num_clients=len(ranks.trainers),
                    group_idxs=group_idxs_for_partition_servers,
                    log_stats=True,
                ),
                process_name=f"PartS-{rank}",
                init_method=config.distributed_init_method,
                world_size=ranks.world_size,
                server_rank=ranks.partition_servers[rank],
                groups=groups,
                subprocess_init=subprocess_init,
            )

        # Rebinds `groups` to whatever init_process_group returns for the rank
        # lists above; the first entry is the trainers' barrier group.
        groups = init_process_group(
            rank=ranks.trainers[rank],
            world_size=ranks.world_size,
            init_method=config.distributed_init_method,
            groups=groups,
        )
        trainer_group, *groups_for_partition_servers = groups
        self.barrier_group = trainer_group

        if len(ranks.partition_servers) > 0:
            partition_client = PartitionClient(
                ranks.partition_servers,
                groups=groups_for_partition_servers,
                log_stats=True,
            )
        else:
            partition_client = None
    else:
        self.barrier_group = None
        self.bucket_scheduler = SingleMachineBucketScheduler(
            holder.nparts_lhs, holder.nparts_rhs, config.bucket_order, stats_handler
        )
        parameter_sharer = None
        partition_client = None
        hide_distributed_logging()

    # fork early for HOGWILD threads
    logger.info("Creating workers...")
    self.num_workers = get_num_workers(config.workers)
    self.pool = create_pool(
        self.num_workers,
        subprocess_name=f"TWorker-{rank}",
        subprocess_init=subprocess_init,
    )

    checkpoint_manager = CheckpointManager(
        config.checkpoint_path,
        rank=rank,
        num_machines=config.num_machines,
        partition_client=partition_client,
        subprocess_name=f"BackgRW-{rank}",
        subprocess_init=subprocess_init,
    )
    self.checkpoint_manager = checkpoint_manager
    checkpoint_manager.register_metadata_provider(ConfigMetadataProvider(config))
    # Only rank 0 writes the config into the checkpoint.
    if rank == 0:
        checkpoint_manager.write_config(config)

    num_edge_chunks = get_num_edge_chunks(config)

    # Resume the iteration from the current checkpoint version.
    self.iteration_manager = IterationManager(
        config.num_epochs,
        config.edge_paths,
        num_edge_chunks,
        iteration_idx=checkpoint_manager.checkpoint_version,
    )
    checkpoint_manager.register_metadata_provider(self.iteration_manager)

    logger.info("Initializing global model...")
    if model is None:
        model = make_model(config)
    model.share_memory()
    # The same loss instance and relation weights are shared by the default
    # trainer and the default evaluator.
    loss_fn = LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin)
    relation_weights = [relation.weight for relation in config.relations]
    if trainer is None:
        trainer = Trainer(
            model_optimizer=make_optimizer(config, model.parameters(), False),
            loss_fn=loss_fn,
            relation_weights=relation_weights,
        )
    if evaluator is None:
        # Only override the negative-sampling settings that are explicitly set
        # for evaluation.
        eval_overrides = {}
        if config.eval_num_batch_negs is not None:
            eval_overrides["num_batch_negs"] = config.eval_num_batch_negs
        if config.eval_num_uniform_negs is not None:
            eval_overrides["num_uniform_negs"] = config.eval_num_uniform_negs

        evaluator = RankingEvaluator(
            loss_fn=loss_fn,
            relation_weights=relation_weights,
            overrides=eval_overrides,
        )

    if config.init_path is not None:
        self.loadpath_manager = CheckpointManager(config.init_path)
    else:
        self.loadpath_manager = None

    # load model from checkpoint or loadpath, if available
    # (the current checkpoint takes precedence over init_path)
    state_dict, optim_state = checkpoint_manager.maybe_read_model()
    if state_dict is None and self.loadpath_manager is not None:
        state_dict, optim_state = self.loadpath_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)
    if optim_state is not None:
        trainer.model_optimizer.load_state_dict(optim_state)

    logger.debug("Loading unpartitioned entities...")
    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        count = entity_counts[entity][0]
        s = embedding_storage_freelist[entity].pop()
        dimension = config.entity_dimension(entity)
        # View the preallocated (max_count x dimension) shared storage as a
        # 2-D tensor and keep only the first `count` rows.
        embs = torch.FloatTensor(s).view(-1, dimension)[:count]
        embs, optimizer = self._load_embeddings(entity, UNPARTITIONED, out=embs)
        holder.unpartitioned_embeddings[entity] = embs
        trainer.unpartitioned_optimizers[entity] = optimizer

    # start communicating shared parameters with the parameter server
    if parameter_sharer is not None:
        shared_parameters: Set[int] = set()
        for name, param in model.named_parameters():
            # Skip parameter objects already registered under another name.
            if id(param) in shared_parameters:
                continue
            shared_parameters.add(id(param))
            key = f"model.{name}"
            logger.info(
                f"Adding {key} ({param.numel()} params) to parameter server"
            )
            parameter_sharer.set_param(key, param.data)
        for entity, embs in holder.unpartitioned_embeddings.items():
            key = f"entity.{entity}"
            logger.info(f"Adding {key} ({embs.numel()} params) to parameter server")
            parameter_sharer.set_param(key, embs.data)

    # store everything in self
    self.model = model
    self.trainer = trainer
    self.evaluator = evaluator
    self.rank = rank
    self.entity_counts = entity_counts
    self.embedding_storage_freelist = embedding_storage_freelist
    self.stats_handler = stats_handler

    self.strict = False
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
    embeddings.

    This is a generator. It yields, in order:
      * ``(edge_path_idx, bucket, stats)`` after each bucket of each edge path;
      * ``(edge_path_idx, None, stats)`` after each edge path, aggregated over
        its buckets;
      * ``(None, None, stats)`` once at the end, aggregated over all edge
        paths.
    Each ``stats`` is the ``average()`` of the summed worker stats.

    Args:
        config: configuration; embeddings and model state are read from
            ``config.checkpoint_path``.
        model: model to evaluate; built with ``make_model(config)`` if None.
        evaluator: batch processor; a ``RankingEvaluator`` is built if None.
        subprocess_init: forwarded to the worker pool.

    The worker pool is closed when the generator runs to completion. It is
    terminated if the generator is closed early or an exception escapes.
    """
    tag_logs_with_process_name("Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator(
            loss_fn=LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin),
            relation_weights=[relation.weight for relation in config.relations],
        )

    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        # Read one embedding table from the checkpoint (optimizer state is
        # discarded) and wrap it as a Parameter.
        embs, _ = checkpoint_manager.read(entity, part)
        # NOTE(review): presumably required so pool workers can access the
        # embeddings without copying; confirm against CheckpointManager.read.
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    holder = EmbeddingHolder(config)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers, subprocess_name="EvalWorker", subprocess_init=subprocess_init
    )
    # Becomes True only once the consumer resumes after the final yield.
    finished = False
    try:
        if model is None:
            model = make_model(config)
        model.share_memory()

        state_dict, _ = checkpoint_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        model.eval()

        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            embs = load_embeddings(entity, UNPARTITIONED)
            holder.unpartitioned_embeddings[entity] = embs

        all_stats: List[Stats] = []
        for edge_path_idx, edge_path in enumerate(config.edge_paths):
            logger.info(
                f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
                f"({edge_path})"
            )
            edge_storage = EDGE_STORAGES.make_instance(edge_path)

            all_edge_path_stats = []
            # FIXME This order assumes higher affinity on the left-hand side, as it's
            # the one changing more slowly. Make this adaptive to the actual affinity.
            for bucket in create_buckets_ordered_lexicographically(
                holder.nparts_lhs, holder.nparts_rhs
            ):
                tic = time.perf_counter()

                # Keep only the partitions this bucket needs: drop the ones no
                # longer used and load the newly required ones.
                old_parts = set(holder.partitioned_embeddings.keys())
                new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} | {
                    (e, bucket.rhs) for e in holder.rhs_partitioned_types
                }
                for entity, part in old_parts - new_parts:
                    del holder.partitioned_embeddings[entity, part]
                for entity, part in new_parts - old_parts:
                    embs = load_embeddings(entity, part)
                    holder.partitioned_embeddings[entity, part] = embs

                model.set_all_embeddings(holder, bucket)

                edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
                num_edges = len(edges)

                load_time = time.perf_counter() - tic
                tic = time.perf_counter()

                # Split this bucket's edges almost equally across the workers
                # and wait for all of them.
                future_all_bucket_stats = pool.map_async(
                    call,
                    [
                        partial(
                            process_in_batches,
                            batch_size=config.batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges[s],
                        )
                        for s in split_almost_equally(num_edges, num_parts=num_workers)
                    ],
                )
                all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

                compute_time = time.perf_counter() - tic
                logger.info(
                    f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                    f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                    f"load time: {load_time:.2g} s"
                )

                total_bucket_stats = Stats.sum(all_bucket_stats)
                all_edge_path_stats.append(total_bucket_stats)
                mean_bucket_stats = total_bucket_stats.average()
                logger.info(
                    f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                    f"bucket {bucket}: {mean_bucket_stats}"
                )

                model.clear_all_embeddings()

                yield edge_path_idx, bucket, mean_bucket_stats

            total_edge_path_stats = Stats.sum(all_edge_path_stats)
            all_stats.append(total_edge_path_stats)
            mean_edge_path_stats = total_edge_path_stats.average()
            logger.info("")
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
                f"{mean_edge_path_stats}"
            )
            logger.info("")

            yield edge_path_idx, None, mean_edge_path_stats

        mean_stats = Stats.sum(all_stats).average()
        logger.info("")
        logger.info(f"Stats: {mean_stats}")
        logger.info("")

        yield None, None, mean_stats
        finished = True
    finally:
        # Always release the worker processes. On early close or error, there
        # may be unfinished tasks, so terminate instead of waiting for them.
        if finished:
            pool.close()
        else:
            pool.terminate()
        pool.join()