Example #1
0
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference = True, load_cache=False):
    """Run embedding inference over a streaming input, then merge per-rank shards.

    Args:
        args: run configuration namespace (batch size, rank, split flags, ...).
        model: model used by InferenceEmbeddingFromStreamDataLoader.
        fn: per-record collate/parse function for StreamingDataset.
        prefix: name prefix for the merged shard files (also selects
            passage-specific code paths via substring match).
        f: the raw input stream handed to StreamingDataset.
        is_query_inference: forwarded to the inference helper.
        load_cache: if True, skip inference and let barrier_array_merge
            load previously written shards instead.

    Returns:
        (full_embedding, full_embedding2id) — both None in split-file mode
        or when merging chose not to load into memory.
    """
    batch_size = args.per_gpu_eval_batch_size  # * max(1, args.n_gpu)
    dataloader = DataLoader(StreamingDataset(f, fn), batch_size=batch_size)

    # All ranks wait here so the output directory exists before any writes.
    if args.local_rank != -1:
        dist.barrier()

    if (args.emb_file_multi_split_num > 0) and ("passage" in prefix):
        # Split-file mode: the inference helper writes fixed-size embedding
        # files itself to bound peak memory; nothing is kept in RAM here.
        _, _ = InferenceEmbeddingFromStreamDataLoader(
            args, model, dataloader,
            is_query_inference=is_query_inference, prefix=prefix)
        full_embedding = None
        full_embedding2id = None  # TODO: loading ids for first_worker()
        return full_embedding, full_embedding2id

    if load_cache:
        # Reuse shards already on disk; no fresh inference pass.
        _embedding, _embedding2id = None, None
    else:
        _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
            args, model, dataloader,
            is_query_inference=is_query_inference, prefix=prefix)

    # For passage embeddings under split ANN search, write shards but do not
    # pull the merged array back into memory.
    not_loading = args.split_ann_search and ("passage" in prefix)
    full_embedding = barrier_array_merge(
        args, _embedding,
        prefix=prefix + "_emb_p_",
        load_cache=load_cache,
        only_load_in_master=True,
        not_loading=not_loading)
    # Drop the local shard reference promptly to reduce peak memory.
    _embedding = None
    del _embedding
    full_embedding2id = barrier_array_merge(
        args, _embedding2id,
        prefix=prefix + "_embid_p_",
        load_cache=load_cache,
        only_load_in_master=True,
        not_loading=not_loading)
    logger.info( f"finish saving embbedding of {prefix}, not loading into MEM: {not_loading}" )
    _embedding2id = None
    del _embedding2id

    return full_embedding, full_embedding2id
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    """Embed a streamed corpus and merge per-rank results into full arrays.

    Args:
        args: run configuration namespace (batch size, local rank, ...).
        model: model forwarded to InferenceEmbeddingFromStreamDataLoader.
        fn: per-record collate/parse function for StreamingDataset.
        prefix: file-name prefix used when persisting merged shards.
        f: raw input stream wrapped by StreamingDataset.
        is_query_inference: forwarded to the inference helper.

    Returns:
        (full_embedding, full_embedding2id) as produced by barrier_array_merge.
    """
    loader = DataLoader(
        StreamingDataset(f, fn),
        batch_size=args.per_gpu_eval_batch_size)  # * max(1, args.n_gpu)

    # Synchronize ranks so the shared output directory exists everywhere.
    if args.local_rank != -1:
        dist.barrier()

    _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
        args, model, loader,
        is_query_inference=is_query_inference, prefix=prefix)

    logger.info("merging embeddings")

    # Persist each rank's shard and gather the merged arrays (master only).
    full_embedding = barrier_array_merge(
        args, _embedding,
        prefix=prefix + "_emb_p_",
        load_cache=False,
        only_load_in_master=True)
    full_embedding2id = barrier_array_merge(
        args, _embedding2id,
        prefix=prefix + "_embid_p_",
        load_cache=False,
        only_load_in_master=True)

    return full_embedding, full_embedding2id
def StreamInferenceDoc(args, model, fn, prefix, f, is_query_inference=True):
    """Embed a streamed corpus, persist shards, and merge per-rank results.

    Args:
        args: run configuration namespace (batch size, rank, split_ann_search).
        model: model forwarded to InferenceEmbeddingFromStreamDataLoader.
        fn: per-record collate/parse function for StreamingDataset.
        prefix: file-name prefix for merged shards; a "passage" substring
            combined with args.split_ann_search keeps embeddings on disk.
        f: raw input stream wrapped by StreamingDataset.
        is_query_inference: forwarded to the inference helper.

    Returns:
        (full_embedding, full_embedding2id); full_embedding is None when the
        merge was told not to load passage embeddings into memory.
    """
    loader = DataLoader(
        StreamingDataset(f, fn),
        batch_size=args.per_gpu_eval_batch_size)  # * max(1, args.n_gpu)

    # Synchronize ranks so the shared output directory exists everywhere.
    if args.local_rank != -1:
        dist.barrier()

    _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader(
        args, model, loader,
        is_query_inference=is_query_inference, prefix=prefix)

    logger.info("merging embeddings")

    # Passage embeddings under split ANN search stay on disk after merging.
    not_loading = args.split_ann_search and ("passage" in prefix)

    full_embedding = barrier_array_merge(
        args, _embedding,
        prefix=prefix + "_emb_p_",
        load_cache=False,
        only_load_in_master=True,
        not_loading=not_loading)
    # Release the local shard early to lower peak memory before the id merge.
    _embedding = None
    del _embedding
    logger.info(
        f"finish saving embbedding of {prefix}, not loading into MEM: {not_loading}"
    )

    # The id array is small, so it is always loaded (no not_loading flag).
    full_embedding2id = barrier_array_merge(
        args, _embedding2id,
        prefix=prefix + "_embid_p_",
        load_cache=False,
        only_load_in_master=True)

    return full_embedding, full_embedding2id