def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint
    with trained embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name="EvalWorker",
        subprocess_init=subprocess_init,
    )

    if model is None:
        model = make_model(config)
    model.share_memory()
    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(
            nparts_lhs, nparts_rhs
        ):
            tic = time.time()
            # logger.info(f"{bucket}: Loading entities")

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])
            all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

            compute_time = time.time() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint
    with trained embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator(
            loss_fn=LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin),
            relation_weights=[relation.weight for relation in config.relations],
        )

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    holder = EmbeddingHolder(config)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers, subprocess_name="EvalWorker", subprocess_init=subprocess_init
    )

    if model is None:
        model = make_model(config)
    model.share_memory()
    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        embs = load_embeddings(entity, UNPARTITIONED)
        holder.unpartitioned_embeddings[entity] = embs

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})"
        )
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        # FIXME This order assumes higher affinity on the left-hand side, as it's
        # the one changing more slowly. Make this adaptive to the actual affinity.
        for bucket in create_buckets_ordered_lexicographically(
            holder.nparts_lhs, holder.nparts_rhs
        ):
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Loading entities")

            old_parts = set(holder.partitioned_embeddings.keys())
            new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} | {
                (e, bucket.rhs) for e in holder.rhs_partitioned_types
            }

            for entity, part in old_parts - new_parts:
                del holder.partitioned_embeddings[entity, part]
            for entity, part in new_parts - old_parts:
                embs = load_embeddings(entity, part)
                holder.partitioned_embeddings[entity, part] = embs

            model.set_all_embeddings(holder, bucket)

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.perf_counter() - tic
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(
                call,
                [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    )
                    for s in split_almost_equally(num_edges, num_parts=num_workers)
                ],
            )
            all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

            compute_time = time.perf_counter() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s"
            )

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}"
            )

            model.clear_all_embeddings()

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}"
        )
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
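
# Illustrative sketch of the partitioned-embedding swap used in the bucket
# loop above, shown in isolation. `cache` stands in for
# holder.partitioned_embeddings and `load` for load_embeddings; both names
# and the generic value type are hypothetical, not PBG APIs. Only the
# (entity, partition) pairs that change between consecutive buckets are
# evicted or loaded, so a bucket that reuses the previous bucket's lhs or
# rhs partition pays no reload cost for that side.
from typing import Callable, Dict, Set, Tuple


def swap_partitions(
    cache: Dict[Tuple[str, int], object],
    needed: Set[Tuple[str, int]],
    load: Callable[[str, int], object],
) -> None:
    current = set(cache.keys())
    for key in current - needed:
        # Evict pairs the new bucket no longer needs.
        del cache[key]
    for key in needed - current:
        # Load only the newly needed pairs; shared pairs are kept as-is.
        cache[key] = load(*key)
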
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint
    with trained embeddings.
    """
    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    if model is None:
        model = make_model(config)
    model.share_memory()
    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        log("Starting edge path %d / %d (%s)"
            % (edge_path_idx + 1, len(config.edge_paths), edge_path))
        edge_reader = EdgeReader(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
            tic = time.time()
            # log("%s: Loading entities" % (bucket,))

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # log("%s: Loading edges" % (bucket,))
            edges = edge_reader.read(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # log("%s: Launching and waiting for workers" % (bucket,))
            all_bucket_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])

            compute_time = time.time() - tic
            log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                % (bucket, num_edges, compute_time,
                   num_edges / compute_time / 1e6, load_time))

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            log("Stats for edge path %d / %d, bucket %s: %s"
                % (edge_path_idx + 1, len(config.edge_paths),
                   bucket, mean_bucket_stats))

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        log("")
        log("Stats for edge path %d / %d: %s"
            % (edge_path_idx + 1, len(config.edge_paths), mean_edge_path_stats))
        log("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    log("")
    log("Stats: %s" % mean_stats)
    log("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
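
# Illustrative sketch of how each bucket's edges are sharded across the
# worker pool in the functions above. `split_slices` is a hypothetical
# stand-in showing the role split_almost_equally plays here: it yields
# `num_parts` contiguous slices whose lengths differ by at most one, so each
# eval worker receives a nearly equal share of the edges.
from typing import Iterator


def split_slices(size: int, num_parts: int) -> Iterator[slice]:
    base, extra = divmod(size, num_parts)
    start = 0
    for i in range(num_parts):
        length = base + (1 if i < extra else 0)
        yield slice(start, start + length)
        start += length


# Example: list(split_slices(10, 3)) -> [slice(0, 4), slice(4, 7), slice(7, 10)]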