Example #1
File: eval.py  Project: delldu/BigGraph
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName,
                        part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers,
        subprocess_name="EvalWorker",
        subprocess_init=subprocess_init,
    )

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(
                nparts_lhs, nparts_rhs):
            tic = time.time()
            # logger.info(f"{bucket}: Loading entities")

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])
            all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

            compute_time = time.time() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
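
Since do_eval_and_report_stats is a generator, a caller tells the three kinds of yielded tuples apart by which fields are None: (idx, bucket, stats) is a per-bucket result, (idx, None, stats) summarizes one edge path, and (None, None, stats) is the overall summary. A minimal consumption sketch (eval_config is a hypothetical, already-loaded ConfigSchema instance):

def report_eval(eval_config):
    for edge_path_idx, bucket, stats in do_eval_and_report_stats(eval_config):
        if edge_path_idx is None:
            print(f"Overall: {stats}")                    # final summary
        elif bucket is None:
            print(f"Edge path {edge_path_idx}: {stats}")  # per-edge-path summary
        else:
            print(f"Edge path {edge_path_idx}, bucket {bucket}: {stats}")
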
Example #2
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """
    tag_logs_with_process_name(f"Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator(
            loss_fn=LOSS_FUNCTIONS.get_class(
                config.loss_fn)(margin=config.margin),
            relation_weights=[
                relation.weight for relation in config.relations
            ],
        )

    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName,
                        part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    holder = EmbeddingHolder(config)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers,
                       subprocess_name="EvalWorker",
                       subprocess_init=subprocess_init)

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
        embs = load_embeddings(entity, UNPARTITIONED)
        holder.unpartitioned_embeddings[entity] = embs

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        logger.info(
            f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
            f"({edge_path})")
        edge_storage = EDGE_STORAGES.make_instance(edge_path)

        all_edge_path_stats = []
        # FIXME This order assumes higher affinity on the left-hand side, as it's
        # the one changing more slowly. Make this adaptive to the actual affinity.
        for bucket in create_buckets_ordered_lexicographically(
                holder.nparts_lhs, holder.nparts_rhs):
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Loading entities")

            old_parts = set(holder.partitioned_embeddings.keys())
            new_parts = {
                (e, bucket.lhs) for e in holder.lhs_partitioned_types
            } | {
                (e, bucket.rhs) for e in holder.rhs_partitioned_types
            }
            for entity, part in old_parts - new_parts:
                del holder.partitioned_embeddings[entity, part]
            for entity, part in new_parts - old_parts:
                embs = load_embeddings(entity, part)
                holder.partitioned_embeddings[entity, part] = embs

            model.set_all_embeddings(holder, bucket)

            # logger.info(f"{bucket}: Loading edges")
            edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.perf_counter() - tic
            tic = time.perf_counter()
            # logger.info(f"{bucket}: Launching and waiting for workers")
            future_all_bucket_stats = pool.map_async(
                call,
                [
                    partial(
                        process_in_batches,
                        batch_size=config.batch_size,
                        model=model,
                        batch_processor=evaluator,
                        edges=edges[s],
                    ) for s in split_almost_equally(num_edges,
                                                    num_parts=num_workers)
                ],
            )
            all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

            compute_time = time.perf_counter() - tic
            logger.info(
                f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                f"load time: {load_time:.2g} s")

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                f"bucket {bucket}: {mean_bucket_stats}")

            model.clear_all_embeddings()

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        logger.info("")
        logger.info(
            f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
            f"{mean_edge_path_stats}")
        logger.info("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    logger.info("")
    logger.info(f"Stats: {mean_stats}")
    logger.info("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
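
The old_parts/new_parts set difference in this version is what lets consecutive buckets share embeddings: only partitions that leave scope are evicted, and only partitions that enter scope are read from the checkpoint. A stripped-down, self-contained sketch of the same bookkeeping (swap_partitions and load_part are hypothetical names, standing in for the holder update and load_embeddings):

def swap_partitions(cache, needed, load_part):
    """Update cache (a dict keyed by (entity, part)) to hold exactly needed."""
    old = set(cache.keys())
    for key in old - needed:   # evict partitions no longer in scope
        del cache[key]
    for key in needed - old:   # load only the partitions that changed
        cache[key] = load_part(*key)

cache = {}
swap_partitions(cache, {("user", 0), ("item", 0)}, lambda e, p: f"embs[{e},{p}]")
swap_partitions(cache, {("user", 0), ("item", 1)}, lambda e, p: f"embs[{e},{p}]")
# The second call loads only ("item", 1); ("user", 0) is reused from the cache.
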
Example #3
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
) -> Generator[Tuple[Optional[int], Optional[Bucket], Stats], None, None]:
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
       embeddings.
    """

    if evaluator is None:
        evaluator = RankingEvaluator()

    if config.verbose > 0:
        import pprint
        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        embs, _ = checkpoint_manager.read(entity, part)
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    nparts_lhs, lhs_partitioned_types = get_partitioned_types(config, Side.LHS)
    nparts_rhs, rhs_partitioned_types = get_partitioned_types(config, Side.RHS)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(num_workers)

    if model is None:
        model = make_model(config)
    model.share_memory()

    state_dict, _ = checkpoint_manager.maybe_read_model()
    if state_dict is not None:
        model.load_state_dict(state_dict, strict=False)

    model.eval()

    for entity, econfig in config.entities.items():
        if econfig.num_partitions == 1:
            embs = load_embeddings(entity, Partition(0))
            model.set_embeddings(entity, embs, Side.LHS)
            model.set_embeddings(entity, embs, Side.RHS)

    all_stats: List[Stats] = []
    for edge_path_idx, edge_path in enumerate(config.edge_paths):
        log("Starting edge path %d / %d (%s)"
            % (edge_path_idx + 1, len(config.edge_paths), edge_path))
        edge_reader = EdgeReader(edge_path)

        all_edge_path_stats = []
        last_lhs, last_rhs = None, None
        for bucket in create_buckets_ordered_lexicographically(nparts_lhs, nparts_rhs):
            tic = time.time()
            # log("%s: Loading entities" % (bucket,))

            if last_lhs != bucket.lhs:
                for e in lhs_partitioned_types:
                    model.clear_embeddings(e, Side.LHS)
                    embs = load_embeddings(e, bucket.lhs)
                    model.set_embeddings(e, embs, Side.LHS)
            if last_rhs != bucket.rhs:
                for e in rhs_partitioned_types:
                    model.clear_embeddings(e, Side.RHS)
                    embs = load_embeddings(e, bucket.rhs)
                    model.set_embeddings(e, embs, Side.RHS)
            last_lhs, last_rhs = bucket.lhs, bucket.rhs

            # log("%s: Loading edges" % (bucket,))
            edges = edge_reader.read(bucket.lhs, bucket.rhs)
            num_edges = len(edges)

            load_time = time.time() - tic
            tic = time.time()
            # log("%s: Launching and waiting for workers" % (bucket,))
            all_bucket_stats = pool.map(call, [
                partial(
                    process_in_batches,
                    batch_size=config.batch_size,
                    model=model,
                    batch_processor=evaluator,
                    edges=edges[s],
                )
                for s in split_almost_equally(num_edges, num_parts=num_workers)
            ])

            compute_time = time.time() - tic
            log("%s: Processed %d edges in %.2g s (%.2gM/sec); load time: %.2g s"
                % (bucket, num_edges, compute_time,
                   num_edges / compute_time / 1e6, load_time))

            total_bucket_stats = Stats.sum(all_bucket_stats)
            all_edge_path_stats.append(total_bucket_stats)
            mean_bucket_stats = total_bucket_stats.average()
            log("Stats for edge path %d / %d, bucket %s: %s"
                % (edge_path_idx + 1, len(config.edge_paths), bucket,
                   mean_bucket_stats))

            yield edge_path_idx, bucket, mean_bucket_stats

        total_edge_path_stats = Stats.sum(all_edge_path_stats)
        all_stats.append(total_edge_path_stats)
        mean_edge_path_stats = total_edge_path_stats.average()
        log("")
        log("Stats for edge path %d / %d: %s"
            % (edge_path_idx + 1, len(config.edge_paths), mean_edge_path_stats))
        log("")

        yield edge_path_idx, None, mean_edge_path_stats

    mean_stats = Stats.sum(all_stats).average()
    log("")
    log("Stats: %s" % mean_stats)
    log("")

    yield None, None, mean_stats

    pool.close()
    pool.join()
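
All three versions shard a bucket's edges across workers with split_almost_equally(num_edges, num_parts=num_workers) and index the edge list with the resulting slices. Its implementation isn't shown here; below is a plausible minimal version (an assumption, not necessarily the library's actual code) that yields num_parts contiguous slices whose lengths differ by at most one:

def split_almost_equally(size, *, num_parts):
    # Distribute `size` items over `num_parts` contiguous slices; the first
    # `extra` slices get one extra item, so lengths differ by at most one.
    base, extra = divmod(size, num_parts)
    start = 0
    for i in range(num_parts):
        length = base + (1 if i < extra else 0)
        yield slice(start, start + length)
        start += length

assert [(s.start, s.stop) for s in split_almost_equally(10, num_parts=3)] == \
    [(0, 4), (4, 7), (7, 10)]
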