コード例 #1
0
ファイル: train.py プロジェクト: AnReu/ColBERT-for-Formulas
def main():
    """Parse command-line options and train ColBERT on <query, pos, neg> triples.

    Seeds ``random`` and ``torch``, parses and validates the CLI options,
    creates ``--output_dir``, resolves ``--triples`` relative to ``--data_dir``
    and finally calls ``train(args)``. Invalid options are reported through
    ``parser.error`` (usage message and exit status 2) before any directory
    is created.
    """
    random.seed(12345)
    torch.manual_seed(1)

    parser = ArgumentParser(
        description=
        'Training ColBERT with <query, positive passage, negative passage> triples.'
    )

    parser.add_argument('--lr', dest='lr', default=3e-06, type=float)
    parser.add_argument('--maxsteps',
                        dest='maxsteps',
                        default=400000,
                        type=int)
    parser.add_argument('--bsize', dest='bsize', default=32, type=int)
    parser.add_argument('--accum', dest='accumsteps', default=2, type=int)

    parser.add_argument('--data_dir',
                        dest='data_dir',
                        default=DEFAULT_DATA_DIR)
    parser.add_argument('--triples',
                        dest='triples',
                        default='triples.train.small.tsv')
    parser.add_argument('--output_dir',
                        dest='output_dir',
                        default='outputs.train/')

    parser.add_argument('--similarity',
                        dest='similarity',
                        default='cosine',
                        choices=['cosine', 'l2'])
    parser.add_argument('--dim', dest='dim', default=128, type=int)
    parser.add_argument('--query_maxlen',
                        dest='query_maxlen',
                        default=32,
                        type=int)
    parser.add_argument('--doc_maxlen',
                        dest='doc_maxlen',
                        default=180,
                        type=int)

    # TODO: Add resume functionality
    # TODO: Save the configuration to the checkpoint.
    # TODO: Extract common parser arguments/behavior into a class.

    args = parser.parse_args()
    # NOTE(review): this stores a reference to ``args`` itself, not a snapshot,
    # so later mutations (e.g. the resolved ``triples`` path) show through it.
    args.input_arguments = args

    # Validate before touching the filesystem. Use parser.error() rather than
    # assert so the checks still run under ``python -O``.
    if args.bsize <= 0:
        parser.error('--bsize must be a positive integer, got %d.' %
                     args.bsize)
    if args.accumsteps <= 0:
        # Also protects the modulo below from a ZeroDivisionError.
        parser.error('--accum must be a positive integer, got %d.' %
                     args.accumsteps)
    if args.bsize % args.accumsteps != 0:
        parser.error(
            'The batch size must be divisible by the number of gradient '
            'accumulation steps (got --bsize %d, --accum %d).' %
            (args.bsize, args.accumsteps))
    # 512 is the upper bound enforced here (presumably BERT's maximum
    # sequence length).
    if args.query_maxlen > 512:
        parser.error('--query_maxlen must be <= 512, got %d.' %
                     args.query_maxlen)
    if args.doc_maxlen > 512:
        parser.error('--doc_maxlen must be <= 512, got %d.' %
                     args.doc_maxlen)

    create_directory(args.output_dir)

    args.triples = os.path.join(args.data_dir, args.triples)

    train(args)
コード例 #2
0
def test_create_directory_relative(tmp_path):
    """create_directory() with a relative path creates that directory.

    The call is made inside ``utils.working_directory(tmp_path)`` (presumably
    a context manager that temporarily changes the working directory), so the
    relative name ``'tempfile'`` should land directly under ``tmp_path``.
    """
    target = tmp_path / 'tempfile'
    # Precondition: pytest's fresh tmp_path must not already contain the target.
    assert not target.exists()

    with utils.working_directory(tmp_path):
        utils.create_directory('tempfile')

    # is_dir() rather than exists(): a regular file at the path must not pass.
    assert target.is_dir()
コード例 #3
0
def main():
    """Parse CLI options and encode a passage collection into a ColBERT index.

    Seeds ``random``, parses the options, attaches a 10-process
    ``multiprocessing.Pool`` as ``args.pool``, and creates ``--output_dir``.
    It then resolves ``--index`` under ``--output_dir`` and ``--collection``
    under ``--data_dir``, loads the model via ``load_colbert(args)`` (which
    replaces ``args.checkpoint``), and calls ``encode(args)``.
    """
    random.seed(123456)

    # The previous description was copy-pasted from the evaluation script and
    # was wrong for this indexing entry point.
    parser = ArgumentParser(
        description='Encode a passage collection into a ColBERT index.')

    parser.add_argument('--index', dest='index', required=True)
    parser.add_argument('--checkpoint', dest='checkpoint', required=True)
    parser.add_argument('--collection',
                        dest='collection',
                        default='collection.tsv')

    parser.add_argument('--data_dir',
                        dest='data_dir',
                        default=DEFAULT_DATA_DIR)
    parser.add_argument('--output_dir',
                        dest='output_dir',
                        default='outputs.index/')

    parser.add_argument('--bsize', dest='bsize', default=128, type=int)
    parser.add_argument('--bytes',
                        dest='bytes',
                        default=2,
                        choices=[2, 4],
                        type=int)
    parser.add_argument('--subsample', dest='subsample',
                        default=None)  # TODO: Add this

    # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint.
    parser.add_argument('--similarity',
                        dest='similarity',
                        default='cosine',
                        choices=['cosine', 'l2'])
    parser.add_argument('--dim', dest='dim', default=128, type=int)
    parser.add_argument('--query_maxlen',
                        dest='query_maxlen',
                        default=32,
                        type=int)
    parser.add_argument('--doc_maxlen',
                        dest='doc_maxlen',
                        default=180,
                        type=int)

    # TODO: Add resume functionality

    args = parser.parse_args()
    # NOTE(review): a self-reference, not a snapshot of the parsed options.
    args.input_arguments = args
    args.pool = Pool(10)

    create_directory(args.output_dir)

    # The index lives under the output directory; the collection under data_dir.
    args.index = os.path.join(args.output_dir, args.index)
    args.collection = os.path.join(args.data_dir, args.collection)

    args.colbert, args.checkpoint = load_colbert(args)

    encode(args)
コード例 #4
0
def _process_or_skip_super_batch(args, super_batch_idx,
                                 number_of_subindexes_already_saved,
                                 batch_indices, super_batch):
    """Encode one super-batch unless an earlier run already saved it."""
    if super_batch_idx < number_of_subindexes_already_saved:
        print("#> Skipping super_batch_idx =", super_batch_idx, ".......")
    else:
        process_batch(args, super_batch_idx, batch_indices, super_batch)


def encode(args, number_of_subindexes_already_saved=0):
    """Encode every passage of ``args.collection`` in super-batches.

    Each line of the collection file is split on a tab into ``pid`` and the
    passage text. The text keeps its trailing newline, and a line with more
    than one tab raises ``ValueError``. ``pid`` must equal the 0-based line
    number. Passages are grouped into super-batches of ``SUPER_BATCH_SIZE``
    lines, and each group is handed to ``process_batch`` together with its
    running super-batch index and the line indices it contains.

    Args:
        args: parsed options. Reads ``index``, ``bsize``, ``colbert``,
            ``output_dir`` and ``collection``. Side effects: ``args.index`` is
            created as a directory, ``args.bsize`` is multiplied in place by
            ``torch.cuda.device_count()``, and the ``bert`` and ``linear``
            submodules of ``args.colbert`` are replaced in place by
            ``nn.DataParallel`` wrappers.
        number_of_subindexes_already_saved: super-batches whose index is below
            this value are skipped instead of re-encoded. This applies to
            every super-batch, including the final partial one.
    """
    # TODO: Create a metadata file; save `args.input_arguments` in there
    create_directory(args.index)

    args.bsize = args.bsize * torch.cuda.device_count()

    print("#> Starting with NUM_GPUs =", torch.cuda.device_count())
    print("#> Accordingly, setting total args.bsize =", args.bsize)

    colbert = args.colbert
    colbert.bert = nn.DataParallel(colbert.bert)
    colbert.linear = nn.DataParallel(colbert.linear)
    colbert = colbert.cuda()
    colbert.eval()

    print('\n\n\n')
    print("#> args.output_dir =", args.output_dir)
    print("#> number_of_subindexes_already_saved =",
          number_of_subindexes_already_saved)
    print('\n\n\n')

    super_batch_idx = 0
    super_batch, batch_indices = [], []

    with open(args.collection) as f:
        for idx, passage in enumerate(f):
            # A full super-batch is flushed lazily, when the next line arrives;
            # whatever remains at EOF is flushed after the loop.
            if len(super_batch) == SUPER_BATCH_SIZE:
                _process_or_skip_super_batch(
                    args, super_batch_idx, number_of_subindexes_already_saved,
                    batch_indices, super_batch)

                print_message("Processed", str(idx), "passages so far...\n")

                super_batch_idx += 1
                super_batch, batch_indices = [], []

            pid, passage = passage.split('\t')
            super_batch.append(passage)
            batch_indices.append(idx)

            assert int(pid) == idx, (pid, idx)

    if super_batch:
        # The trailing batch previously bypassed the resume check and was
        # always re-encoded. It now follows the same skip rule as the others.
        _process_or_skip_super_batch(
            args, super_batch_idx, number_of_subindexes_already_saved,
            batch_indices, super_batch)
コード例 #5
0
ファイル: train.py プロジェクト: kaist-irnlp/SparseColBERT
def main():
    """Parse command-line options and train (Sparse)ColBERT on triples.

    Seeds ``random`` and ``torch``, parses and validates the CLI options,
    creates ``--output_dir``, resolves ``--triples`` relative to ``--data_dir``
    and finally calls ``train(args)``. Invalid options are reported through
    ``parser.error`` (usage message and exit status 2) before any directory
    is created. Passing ``--dont_normalize_sparse`` sets
    ``args.normalize_sparse`` to False (it defaults to True).
    """
    random.seed(12345)
    torch.manual_seed(1)

    parser = ArgumentParser(
        description=
        "Training ColBERT with <query, positive passage, negative passage> triples."
    )

    parser.add_argument("--lr", dest="lr", default=3e-06, type=float)
    parser.add_argument("--maxsteps",
                        dest="maxsteps",
                        default=400000,
                        type=int)
    parser.add_argument("--bsize", dest="bsize", default=16, type=int)
    parser.add_argument("--accum", dest="accumsteps", default=4, type=int)

    parser.add_argument("--data_dir",
                        dest="data_dir",
                        default=DEFAULT_DATA_DIR)
    parser.add_argument("--triples",
                        dest="triples",
                        default="triples.train.small.tsv")
    parser.add_argument("--output_dir",
                        dest="output_dir",
                        default="outputs.train/")

    parser.add_argument("--similarity",
                        dest="similarity",
                        default="cosine",
                        choices=["cosine", "l2"])
    parser.add_argument("--dim", dest="dim", default=128, type=int)
    parser.add_argument("--query_maxlen",
                        dest="query_maxlen",
                        default=32,
                        type=int)
    parser.add_argument("--doc_maxlen",
                        dest="doc_maxlen",
                        default=180,
                        type=int)
    parser.add_argument("--use_dense", action="store_true")
    parser.add_argument("--base_model", default="bert-base-uncased", type=str)
    parser.add_argument("--n", default=4096, type=int)
    parser.add_argument("--k", default=0.005, type=float)
    parser.add_argument("--dont_normalize_sparse",
                        dest="normalize_sparse",
                        action="store_false")
    parser.add_argument("--use_nonneg", action="store_true")
    parser.add_argument("--use_ortho", action="store_true")

    # TODO: Add resume functionality
    # TODO: Save the configuration to the checkpoint.
    # TODO: Extract common parser arguments/behavior into a class.

    args = parser.parse_args()
    # NOTE(review): this stores a reference to ``args`` itself, not a snapshot,
    # so later mutations (e.g. the resolved ``triples`` path) show through it.
    args.input_arguments = args

    # Validate before touching the filesystem. Use parser.error() rather than
    # assert so the checks still run under ``python -O``.
    if args.bsize <= 0:
        parser.error("--bsize must be a positive integer, got %d." %
                     args.bsize)
    if args.accumsteps <= 0:
        # Also protects the modulo below from a ZeroDivisionError.
        parser.error("--accum must be a positive integer, got %d." %
                     args.accumsteps)
    if args.bsize % args.accumsteps != 0:
        parser.error(
            "The batch size must be divisible by the number of gradient "
            "accumulation steps (got --bsize %d, --accum %d)." %
            (args.bsize, args.accumsteps))
    # 512 is the upper bound enforced here (presumably BERT's maximum
    # sequence length).
    if args.query_maxlen > 512:
        parser.error("--query_maxlen must be <= 512, got %d." %
                     args.query_maxlen)
    if args.doc_maxlen > 512:
        parser.error("--doc_maxlen must be <= 512, got %d." %
                     args.doc_maxlen)

    create_directory(args.output_dir)

    args.triples = os.path.join(args.data_dir, args.triples)

    train(args)
コード例 #6
0
def main():
    """Evaluate ColBERT re-ranking of a top-K file against qrels.

    Parses and validates the CLI options and attaches a 10-process
    ``multiprocessing.Pool`` as ``args.pool``. ``args.run_name`` is set to
    the ``--topk`` value as given. The function then creates ``--output_dir``,
    resolves ``--topk`` and ``--qrels`` under ``--data_dir``, loads the model,
    qrels and top-K data, reports recall via ``evaluate_recall``, and runs
    ``evaluate(args)``. The original checkpoint argument is preserved as
    ``args.checkpoint_path`` because ``load_colbert`` replaces
    ``args.checkpoint``.
    """
    random.seed(123456)

    parser = ArgumentParser(
        description=
        "Exhaustive (non-index-based) evaluation of re-ranking with ColBERT.")

    parser.add_argument("--checkpoint", dest="checkpoint", required=True)
    parser.add_argument("--topk", dest="topK", default="top1000.dev")
    parser.add_argument("--qrels", dest="qrels", default="qrels.dev.small.tsv")
    parser.add_argument("--shortcircuit",
                        dest="shortcircuit",
                        default=False,
                        action="store_true")

    parser.add_argument("--data_dir",
                        dest="data_dir",
                        default=DEFAULT_DATA_DIR)
    parser.add_argument("--output_dir",
                        dest="output_dir",
                        default="outputs.test/")

    parser.add_argument("--bsize", dest="bsize", default=128, type=int)
    parser.add_argument("--subsample", dest="subsample",
                        default=None)  # TODO: Add this
    parser.add_argument("--dense", action="store_true")

    # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint.
    parser.add_argument("--similarity",
                        dest="similarity",
                        default="cosine",
                        choices=["cosine", "l2"])
    parser.add_argument("--dim", dest="dim", default=128, type=int)
    parser.add_argument("--query_maxlen",
                        dest="query_maxlen",
                        default=32,
                        type=int)
    parser.add_argument("--doc_maxlen",
                        dest="doc_maxlen",
                        default=180,
                        type=int)
    parser.add_argument("--n", type=int, required=True)
    parser.add_argument("--k", type=float, required=True)
    parser.add_argument("--dont_normalize_sparse",
                        dest="normalize_sparse",
                        action="store_false")
    parser.add_argument("--use_nonneg", action="store_true")
    parser.add_argument("--use_ortho", action="store_true")

    args = parser.parse_args()
    # NOTE(review): a self-reference, not a snapshot of the parsed options.
    args.input_arguments = args

    # parser.error() rather than assert so this check survives ``python -O``.
    if args.shortcircuit and not args.qrels:
        parser.error(
            "Short-circuiting (i.e., applying minimal computation to queries with no positives [in the re-ranked set]) "
            "can only be applied if qrels is provided.")

    args.pool = Pool(10)
    # Record the run name before topK is rewritten into a full path below.
    args.run_name = args.topK

    create_directory(args.output_dir)

    args.topK = os.path.join(args.data_dir, args.topK)

    if args.qrels:
        args.qrels = os.path.join(args.data_dir, args.qrels)

    args.checkpoint_path = args.checkpoint
    args.colbert, args.checkpoint = load_colbert(args)
    args.qrels = load_qrels(args.qrels)
    args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

    evaluate_recall(args.qrels, args.queries, args.topK_pids)
    evaluate(args)
コード例 #7
0
def main():
    """Parse CLI options and encode a passage collection into a (Sparse)ColBERT index.

    Seeds ``random``, parses the options, attaches a 4-process
    ``multiprocessing.Pool`` as ``args.pool``, and creates ``--output_dir``.
    It then resolves ``--index`` under ``--output_dir`` and ``--collection``
    under ``--data_dir``, loads the model via ``load_colbert(args)`` (which
    replaces ``args.checkpoint``), and calls ``encode(args)``. Passing
    ``--dont_normalize_sparse`` sets ``args.normalize_sparse`` to False (it
    defaults to True).
    """
    random.seed(123456)

    # The previous description was copy-pasted from the evaluation script and
    # was wrong for this indexing entry point.
    parser = ArgumentParser(
        description="Encode a passage collection into a ColBERT index.")

    parser.add_argument("--index", dest="index", required=True)
    parser.add_argument("--checkpoint", dest="checkpoint", required=True)
    parser.add_argument("--collection",
                        dest="collection",
                        default="collection.tsv")

    parser.add_argument("--data_dir",
                        dest="data_dir",
                        default=DEFAULT_DATA_DIR)
    parser.add_argument("--output_dir",
                        dest="output_dir",
                        default="outputs.index/")

    parser.add_argument("--bsize", dest="bsize", default=128, type=int)
    parser.add_argument("--bytes",
                        dest="bytes",
                        default=4,
                        choices=[2, 4],
                        type=int)
    parser.add_argument("--subsample", dest="subsample",
                        default=None)  # TODO: Add this

    # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint.
    parser.add_argument("--similarity",
                        dest="similarity",
                        default="cosine",
                        choices=["cosine", "l2"])
    parser.add_argument("--dim", dest="dim", default=128, type=int)
    parser.add_argument("--query_maxlen",
                        dest="query_maxlen",
                        default=32,
                        type=int)
    parser.add_argument("--doc_maxlen",
                        dest="doc_maxlen",
                        default=180,
                        type=int)
    parser.add_argument("--dense", action="store_true")
    parser.add_argument("--n", default=4096, type=int)
    parser.add_argument("--k", default=0.005, type=float)
    parser.add_argument("--dont_normalize_sparse",
                        dest="normalize_sparse",
                        action="store_false")

    # TODO: Add resume functionality

    args = parser.parse_args()
    # NOTE(review): a self-reference, not a snapshot of the parsed options.
    args.input_arguments = args
    args.pool = Pool(4)

    create_directory(args.output_dir)

    # The index lives under the output directory; the collection under data_dir.
    args.index = os.path.join(args.output_dir, args.index)
    args.collection = os.path.join(args.data_dir, args.collection)

    args.colbert, args.checkpoint = load_colbert(args)

    encode(args)
コード例 #8
0
ファイル: test.py プロジェクト: DI4IR/SIGIR2021
def main():
    """Evaluate ColBERT re-ranking of a top-K file, optionally against qrels.

    Parses and validates the CLI options and attaches a 10-process
    ``multiprocessing.Pool`` as ``args.pool``. ``args.run_name`` is set to
    the ``--topk`` value as given. The function then creates ``--output_dir``,
    resolves ``--topk`` (and ``--qrels``, when given) under ``--data_dir``,
    loads the model, qrels and top-K data, reports recall via
    ``evaluate_recall``, and runs ``evaluate(args)``. ``--qrels`` defaults to
    None, in which case ``load_qrels`` receives None.
    """
    random.seed(123456)

    parser = ArgumentParser(
        description=
        'Exhaustive (non-index-based) evaluation of re-ranking with ColBERT.')

    parser.add_argument('--checkpoint', dest='checkpoint', required=True)
    parser.add_argument('--topk', dest='topK', required=True)
    parser.add_argument('--qrels', dest='qrels', default=None)
    parser.add_argument('--shortcircuit',
                        dest='shortcircuit',
                        default=False,
                        action='store_true')

    parser.add_argument('--data_dir',
                        dest='data_dir',
                        default=DEFAULT_DATA_DIR)
    parser.add_argument('--output_dir',
                        dest='output_dir',
                        default='outputs.test/')

    parser.add_argument('--bsize', dest='bsize', default=128, type=int)
    parser.add_argument('--subsample', dest='subsample',
                        default=None)  # TODO: Add this

    # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint.
    parser.add_argument('--similarity',
                        dest='similarity',
                        default='cosine',
                        choices=['cosine', 'l2'])
    parser.add_argument('--dim', dest='dim', default=128, type=int)
    parser.add_argument('--query_maxlen',
                        dest='query_maxlen',
                        default=32,
                        type=int)
    parser.add_argument('--doc_maxlen',
                        dest='doc_maxlen',
                        default=180,
                        type=int)

    args = parser.parse_args()
    # NOTE(review): a self-reference, not a snapshot of the parsed options.
    args.input_arguments = args

    # parser.error() rather than assert so this check survives ``python -O``.
    if args.shortcircuit and not args.qrels:
        parser.error(
            "Short-circuiting (i.e., applying minimal computation to queries with no positives [in the re-ranked set]) "
            "can only be applied if qrels is provided.")

    args.pool = Pool(10)
    # Record the run name before topK is rewritten into a full path below.
    args.run_name = args.topK

    create_directory(args.output_dir)

    args.topK = os.path.join(args.data_dir, args.topK)

    if args.qrels:
        args.qrels = os.path.join(args.data_dir, args.qrels)

    args.colbert, args.checkpoint = load_colbert(args)
    args.qrels = load_qrels(args.qrels)
    args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK)

    evaluate_recall(args.qrels, args.queries, args.topK_pids)
    evaluate(args)
コード例 #9
0
def test_create_directory_absolute(tmp_path):
    """create_directory() given an absolute path creates that directory.

    pytest's ``tmp_path`` is an absolute ``pathlib.Path``, so the target is
    passed as an absolute ``Path`` object.
    """
    target = tmp_path / 'tempfile'
    # Precondition: pytest's fresh tmp_path must not already contain the target.
    assert not target.exists()

    utils.create_directory(target)

    # is_dir() rather than exists(): a regular file at the path must not pass.
    assert target.is_dir()