# Assumed imports for these snippets; project-local modules (const, models,
# optimizers, biencoder, biencoder_manipulator, get_tokenizer, get_encoder,
# load_checkpoint, DenseFlatIndexer, spread_samples_equally, get_best_span,
# compare_spans, write_config, ...) come from the surrounding repository and
# are not reproduced here.
import argparse
import copy
import glob
import os
import time
from argparse import Namespace
from datetime import datetime
from functools import partial
from multiprocessing import Pool as ProcessPool

import pandas as pd
import tensorflow as tf
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--jsonl-path",
                        type=str,
                        default="data/retriever/nq-train.jsonl")
    parser.add_argument("--tfrecord-text-path",
                        type=str,
                        default="data/retriever/V3/N5000-TEXT")
    parser.add_argument("--tfrecord-int-path",
                        type=str,
                        default="data/retriever/V3/N5000-INT")
    parser.add_argument(
        "--ctx-source-path",
        type=str,
        default="gs://openqa-dpr/data/wikipedia_split/shards-42031")
    parser.add_argument("--ctx-tokenized-path",
                        type=str,
                        default="data/wikipedia_split/shards-42031-tfrecord")
    parser.add_argument("--shard-size", type=int, default=42031)
    parser.add_argument("--max-context-length",
                        type=int,
                        default=const.MAX_CONTEXT_LENGTH)
    parser.add_argument("--max-query-length",
                        type=int,
                        default=const.MAX_QUERY_LENGTH)
    parser.add_argument("--records-per-file", type=int, default=5000)
    parser.add_argument("--shuffle", type=eval, default=const.SHUFFLE)
    parser.add_argument("--shuffle-seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--pretrained-model",
                        type=str,
                        default=const.PRETRAINED_MODEL)
    parser.add_argument("--qas-path", type=str, default=const.QUERY_PATH)
    parser.add_argument("--qas-tfrecord-path",
                        type=str,
                        default="data/qas/nq-test.tfrecord")

    args = parser.parse_args()
    # TODO: build tfrecord dataset
    # build_tfrecord_text_data_from_jsonl(
    #     input_path="data/retriever/V2/vicovid-train.jsonl",
    #     out_dir="data/retriever/V2/TEXT",
    #     records_per_file=5000,
    #     num_hard_negatives=1
    # )

    build_tfrecord_int_data_from_tfrecord_text_data(
        input_path='data/retriever/V2/TEXT',
        out_dir="data/retriever/V2/INT",
        tokenizer=get_tokenizer(model_name='NlpHUST/vibert4news-base-cased',
                                prefix='pretrained'),
        shuffle=True,
        shuffle_seed=123,
        num_hard_negatives=1)

def build_tfrecord_tokenized_data_for_qas_ver2(pretrained_model: str,
                                               qas_path: str, out_dir: str,
                                               prefix):
    tokenizer = get_tokenizer(model_name=pretrained_model, prefix=prefix)
    text_dataset = tf.data.TextLineDataset(qas_path)

    def _transform():
        for element in text_dataset:
            question, answers = element.numpy().decode().split("\t")
            question_ids = tokenizer.encode(question)
            question_ids = tf.convert_to_tensor(question_ids)
            yield tf.sparse.from_dense(question_ids)

    sparse_dataset = tf.data.Dataset.from_generator(
        _transform,
        output_signature=tf.SparseTensorSpec(shape=[None], dtype=tf.int32))

    def _serialize(element):
        element_serialized = tf.io.serialize_tensor(
            tf.io.serialize_sparse(element))
        features = {
            'question_serialized':
            tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[element_serialized.numpy()]))
        }
        example = tf.train.Example(features=tf.train.Features(
            feature=features))
        return example.SerializeToString()

    def _generator():
        count = 0
        for element in sparse_dataset:
            yield _serialize(element)
            count += 1
            print("Count: {}".format(count))

    dataset = tf.data.Dataset.from_generator(_generator,
                                             output_signature=tf.TensorSpec(
                                                 [], tf.string))

    file_name = os.path.basename(qas_path).split(".")[0]
    writer = tf.data.experimental.TFRecordWriter(
        os.path.join(out_dir, "{}-ver2.tfrecord".format(file_name)))
    writer.write(dataset)
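
# A minimal read-back sketch (hypothetical helper, not part of the original
# code). The feature name 'question_serialized' and the nested serialization
# (tf.io.serialize_sparse wrapped in tf.io.serialize_tensor) mirror
# _serialize() above; everything else here is an assumption.
def load_qas_ver2_tfrecord(tfrecord_path):
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    for record in dataset:
        parsed = tf.io.parse_single_example(
            record,
            {'question_serialized': tf.io.FixedLenFeature([], tf.string)})
        # Undo serialize_tensor: recover the [indices, values, shape] string
        # triple produced by tf.io.serialize_sparse().
        components = tf.io.parse_tensor(parsed['question_serialized'],
                                        out_type=tf.string)
        # components[1] holds the serialized token ids (the sparse values).
        yield tf.io.parse_tensor(components[1], out_type=tf.int32)
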
def validate(
    reader,
    strategy,
    dataset,
    params=None
):
    print("Validating...")

    if params is not None:
        global args
        args = params
    
    # Forward pass of the reader on each replica; with multiple replicas,
    # strategy.run returns PerReplica-wrapped (start_logits, end_logits).
    def dist_forward_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)

            start_logits, end_logits = reader(
                input_ids=input_ids,
                attention_mask=attention_mask,
                training=False
            )

            return start_logits, end_logits

        per_replica_logits = strategy.run(step_fn, args=(input_ids,))
        return per_replica_logits
    
    if not args.disable_tf_function:
        dist_forward_step = tf.function(dist_forward_step)

    # Per-replica value factory for experimental_distribute_values_from_function:
    # each replica receives the slice of `tensors` given by its (start, end) pair.
    def value_fn_template(ctx, indices, tensors):
        start, end = indices[ctx.replica_id_in_sync_group]
        return tensors[start : end]

    processes = ProcessPool(processes=os.cpu_count())
    tokenizer = get_tokenizer(model_name=args.pretrained_model, prefix=args.prefix)
    get_best_span_partial = partial(get_best_span, max_answer_length=args.max_answer_length, tokenizer=tokenizer)
        
    iterator = iter(dataset)
    em_hits = []
    match_stats = []
    for element in tqdm(iterator):
        answers_serialized = element['answers']
        question = element['question']
        passage_offset = element['passage_offset']
        input_ids = element['input_ids']

        if strategy.num_replicas_in_sync > 1:
            reduced_input_ids = tf.concat(input_ids.values, axis=0)
        else:
            reduced_input_ids = input_ids
        
        global_batch_size = reduced_input_ids.shape[0]
        
        # forward pass
        if global_batch_size < args.batch_size * strategy.num_replicas_in_sync:
            base_replica_batch_size = args.batch_size
            flag = False
        
            # Spread the undersized batch across replicas in equal chunks,
            # looping until every remaining sample has been forwarded.
            while True:
                spread, global_batch_size, base_replica_batch_size = spread_samples_equally(
                    global_batch_size=global_batch_size,
                    num_replicas=strategy.num_replicas_in_sync,
                    base_replica_batch_size=base_replica_batch_size,
                    init_batch_size=args.batch_size
                )

                if len(spread) > 1:
                    indices = []
                    idx = 0
                    for num in spread:
                        indices.append((idx, idx + num))
                        idx += num
                    
                    value_fn = partial(value_fn_template, indices=indices, tensors=reduced_input_ids)
                    reduced_input_ids = reduced_input_ids[base_replica_batch_size * strategy.num_replicas_in_sync:]
                    dist_input_ids = strategy.experimental_distribute_values_from_function(value_fn)
                    start_logits, end_logits = dist_forward_step(dist_input_ids)
                    
                    if not flag:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat(start_logits.values, axis=0)
                            global_end_logits = tf.concat(end_logits.values, axis=0)
                        else:
                            global_start_logits = start_logits
                            global_end_logits = end_logits
                        flag = True

                    else:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat([global_start_logits, *start_logits.values], axis=0)
                            global_end_logits = tf.concat([global_end_logits, *end_logits.values], axis=0)
                        else:
                            global_start_logits = tf.concat([global_start_logits, start_logits], axis=0)
                            global_end_logits = tf.concat([global_end_logits, end_logits], axis=0)

                    if global_batch_size == 0:
                        break

                else:
                    start_logits, end_logits = dist_forward_step(reduced_input_ids)
                    if not flag:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = start_logits.values[0]
                            global_end_logits = end_logits.values[0]
                        else:
                            global_start_logits = start_logits
                            global_end_logits = end_logits
                        flag = True
                    else:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat([global_start_logits, start_logits.values[0]], axis=0)
                            global_end_logits = tf.concat([global_end_logits, end_logits.values[0]], axis=0)
                        else:
                            global_start_logits = tf.concat([global_start_logits, start_logits], axis=0)
                            global_end_logits = tf.concat([global_end_logits, end_logits], axis=0)
                    break
 
        else:
            start_logits, end_logits = dist_forward_step(input_ids)
            if strategy.num_replicas_in_sync > 1:
                global_start_logits = tf.concat(start_logits.values, axis=0)
                global_end_logits = tf.concat(end_logits.values, axis=0)
            else:
                global_start_logits = start_logits
                global_end_logits = end_logits
        
        if strategy.num_replicas_in_sync > 1:
            input_ids = tf.concat(input_ids.values, axis=0)
            passage_offset = tf.concat(passage_offset.values, axis=0)
            answers_serialized = tf.concat(answers_serialized.values, axis=0)
            question = tf.concat(question.values, axis=0)
        
        question = question.numpy().tolist()
        question = [q.decode() for q in question]
        
        # Strip trailing padding via a ragged view, then drop the first
        # `passage_offset` tokens (the question segment) so only passage
        # tokens remain for span extraction.
        sentence_ids = tf.RaggedTensor.from_tensor(input_ids, padding=tokenizer.pad_token_id)
        sentence_ids = sentence_ids.to_list()
        ctx_ids = [ids[offset:] for ids, offset in zip(sentence_ids, passage_offset)]

        start_logits = global_start_logits.numpy().tolist()
        start_logits = [logits[offset : offset + len(ctx)] for logits, offset, ctx in zip(start_logits, passage_offset, ctx_ids)]
        end_logits = global_end_logits.numpy().tolist()
        end_logits = [logits[offset : offset + len(ctx)] for logits, offset, ctx in zip(end_logits, passage_offset, ctx_ids)]
        
        best_spans = processes.starmap(get_best_span_partial, zip(start_logits, end_logits, ctx_ids))

        answers = []
        for ans in answers_serialized:
            ans_sparse = tf.io.parse_tensor(ans, out_type=tf.string)
            ans_values = tf.io.parse_tensor(ans_sparse[1], out_type=tf.string)
            ans_values = [answer.numpy().decode() for answer in ans_values]
            answers.append(ans_values)

        hits = processes.starmap(compare_spans, zip(answers, best_spans))
        passages = [tokenizer.decode(ids) for ids in ctx_ids]

        stats = [
            {
                "question": q,
                "answers": ans,
                "passage": psg,
                "predicted": span,
                "hit": hit
            }
            for q, ans, span, psg, hit in zip(question, answers, best_spans, passages, hits)
        ]
        match_stats.extend(stats)
        em_hits.extend(hits)

    print("done")
    print("-----------------------------------------------------------")
    return em_hits, match_stats

def main():
    kwargs = {
        'ctx_source_path':
        'data/wikipedia_split/vi_covid_subset_ctx_source.tsv',
        'pretrained_model': 'NlpHUST/vibert4news-base-cased',
        'use_pooler': False,
        'max_query_length': 64
    }
    args = Namespace(**kwargs)

    index_path = "indexer/vicovid_inbatch_batch8_query64"
    indexer = DenseFlatIndexer(buffer_size=50000)
    index_files = glob.glob("{}/*".format(index_path))
    print("Deserializing indexer from disk... ")
    indexer.deserialize(index_path)

    all_docs = load_ctx_sources(args.ctx_source_path)

    tokenizer = get_tokenizer(model_name='NlpHUST/vibert4news-base-cased',
                              prefix='pretrained')
    question_encoder = load_checkpoint(args)

    retrieve_func = partial(retrieve,
                            indexer=indexer,
                            question_encoder=question_encoder,
                            tokenizer=tokenizer,
                            top_docs=10,
                            all_docs=all_docs)

    print("done")
def build_tfrecord_tokenized_data_for_ctx_sources(
        pretrained_model: str,
        ctx_source_path: str,
        out_dir: str,
        max_context_length: int = 256,
        shard_size: int = 42031,
        prefix=None):
    tokenizer = get_tokenizer(model_name=pretrained_model, prefix=prefix)
    ctx_source_files = glob.glob("{}/*.tsv".format(ctx_source_path))
    ctx_source_files.sort()
    text_dataset = None
    for ctx_source in ctx_source_files:
        df = pd.read_csv(ctx_source,
                         sep="\t",
                         header=None,
                         names=['id', 'text', 'title'])
        if text_dataset is None:
            text_dataset = df
        else:
            text_dataset = pd.concat([text_dataset, df],
                                     axis='index',
                                     ignore_index=True)

    # Generator yielding one tokenized passage at a time, formatted as
    # [CLS] title [SEP] text [SEP] and padded/truncated to max_context_length.
    def _transform():
        count = 0
        for _, element in text_dataset.iterrows():
            id, text, title = element.id, element.text, element.title
            id = str(id)
            text = str(text)
            title = str(title)
            passage_id = "wiki:{}".format(id)

            text_tokens = tokenizer.tokenize(text)
            title_tokens = tokenizer.tokenize(title)
            sent_tokens = [tokenizer.cls_token] + title_tokens \
                          + [tokenizer.sep_token] + text_tokens + [tokenizer.sep_token]

            if len(sent_tokens) < max_context_length:
                sent_tokens += [tokenizer.pad_token
                                ] * (max_context_length - len(sent_tokens))

            sent_tokens = sent_tokens[:max_context_length]
            sent_tokens[-1] = tokenizer.sep_token

            context_ids = tokenizer.convert_tokens_to_ids(sent_tokens)
            context_ids = tf.convert_to_tensor(context_ids, dtype=tf.int32)

            count += 1
            print("Count: {}".format(count))

            yield {
                'context_ids': context_ids,
                'passage_id': tf.constant(passage_id)
            }

    dataset = tf.data.Dataset.from_generator(
        _transform,
        output_signature={
            'context_ids': tf.TensorSpec([max_context_length], tf.int32),
            'passage_id': tf.TensorSpec([], tf.string)
        })

    def _serialize(context_ids, passage_id):
        features = {
            'context_ids':
            tf.train.Feature(int64_list=tf.train.Int64List(value=context_ids)),
            'passage_id':
            tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[passage_id.numpy()]))
        }

        example = tf.train.Example(features=tf.train.Features(
            feature=features))
        return example.SerializeToString()

    dataset = dataset.map(lambda x: tf.py_function(
        _serialize, inp=[x['context_ids'], x['passage_id']], Tout=tf.string),
                          num_parallel_calls=tf.data.AUTOTUNE,
                          deterministic=True)

    dataset = dataset.window(shard_size)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    idx = 0
    for window in dataset:
        writer = tf.data.experimental.TFRecordWriter(
            os.path.join(out_dir, "psgs_subset_{:02d}.tfrecord".format(idx)))
        writer.write(window)
        idx += 1
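
# A minimal parsing sketch (hypothetical helper, not part of the original code)
# showing how the shards written above can be read back; the feature spec
# mirrors _serialize(), and max_context_length must match the value used at
# build time.
def load_ctx_tfrecord_shards(out_dir, max_context_length=256):
    files = sorted(glob.glob("{}/*.tfrecord".format(out_dir)))
    feature_description = {
        # Written as an Int64List above, so parsed back as tf.int64.
        'context_ids': tf.io.FixedLenFeature([max_context_length], tf.int64),
        'passage_id': tf.io.FixedLenFeature([], tf.string)
    }
    dataset = tf.data.TFRecordDataset(files)
    return dataset.map(
        lambda record: tf.io.parse_single_example(record, feature_description),
        num_parallel_calls=tf.data.AUTOTUNE)
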
def validate(dataset, strategy, ranker, reader, params=None):
    if params is not None:
        global args
        args = params

    def dist_forward_ranker_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            input_ids = tf.reshape(input_ids, [-1, args.max_sequence_length])
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)

            rank_logits = ranker(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 training=False)

            return tf.reshape(rank_logits, [args.batch_size, -1])

        per_replica_logits = strategy.run(step_fn, args=(input_ids, ))
        return per_replica_logits

    def dist_forward_reader_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)

            start_logits, end_logits = reader(input_ids=input_ids,
                                              attention_mask=attention_mask,
                                              training=False)

            return start_logits, end_logits

        per_replica_results = strategy.run(step_fn, args=(input_ids, ))
        return per_replica_results

    if not args.disable_tf_function:
        dist_forward_ranker_step = tf.function(dist_forward_ranker_step)
        dist_forward_reader_step = tf.function(dist_forward_reader_step)

    # Each replica picks its own tensor from the pool by replica id.
    def value_fn_template(ctx, pool_tensors):
        return pool_tensors[ctx.replica_id_in_sync_group]

    processes = ProcessPool(processes=os.cpu_count())
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)

    get_best_span_partial = partial(get_best_span,
                                    max_answer_length=args.max_answer_length,
                                    tokenizer=tokenizer)

    iterator = iter(dataset)
    em_hits = []
    match_stats = []
    for element in tqdm(iterator):
        answers_serialized = element['answers']
        question = element['question']
        input_ids = element[
            'passages/sequence_ids']  # bsz x num_passages x max_sequence_length
        passage_offsets = element[
            'passages/passage_offset']  # bsz x num_passages

        reduced_input_ids = tf.concat(input_ids.values, axis=0)
        per_replica_passage_offsets = strategy.experimental_local_results(
            passage_offsets)

        global_batch_size = reduced_input_ids.shape[0]
        if global_batch_size < args.batch_size * strategy.num_replicas_in_sync:
            # The last batch may be smaller than batch_size * num_replicas_in_sync;
            # pad it to full size so each replica receives an equal shard.
            aggregated_input_ids = tf.concat(input_ids.values, axis=0)
            padded_size = args.batch_size * strategy.num_replicas_in_sync - global_batch_size
            padded_input_ids = tf.zeros(
                [padded_size, args.max_passages, args.max_sequence_length],
                dtype=tf.int32)
            input_ids = tf.concat([aggregated_input_ids, padded_input_ids],
                                  axis=0)
            pool_input_ids = tf.split(
                input_ids,
                num_or_size_splits=strategy.num_replicas_in_sync,
                axis=0)
            value_fn_for_input_ids = partial(value_fn_template,
                                             pool_tensors=pool_input_ids)
            input_ids = strategy.experimental_distribute_values_from_function(
                value_fn_for_input_ids)

            aggregated_per_replica_passage_offsets = tf.concat(
                per_replica_passage_offsets, axis=0)
            lack_size = args.batch_size * strategy.num_replicas_in_sync - aggregated_per_replica_passage_offsets.shape[
                0]
            padded_per_replica_passage_offsets = tf.zeros(
                [lack_size, args.max_passages], dtype=tf.int32)
            per_replica_passage_offsets = tf.concat([
                aggregated_per_replica_passage_offsets,
                padded_per_replica_passage_offsets
            ],
                                                    axis=0)
            per_replica_passage_offsets = tf.split(
                per_replica_passage_offsets,
                num_or_size_splits=strategy.num_replicas_in_sync)

        rank_logits = dist_forward_ranker_step(input_ids)
        rank_logits = strategy.experimental_local_results(rank_logits)
        selected_passage_idxs = [
            tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)
            for logits in rank_logits
        ]  # num_replicas x batch_size

        selected_passage_offsets = []
        per_replica_input_ids = strategy.experimental_local_results(
            input_ids
        )  # num_replicas x batch_size x max_passages x max_sequence_length
        selected_input_ids = []

        for sequence_ids, psg_offsets, passage_idxs in zip(
                per_replica_input_ids, per_replica_passage_offsets,
                selected_passage_idxs):
            range_idxs = tf.range(sequence_ids.shape[0], dtype=tf.int32)
            indices = tf.concat([
                tf.expand_dims(range_idxs, axis=1),
                tf.expand_dims(passage_idxs, axis=1)
            ],
                                axis=1)
            selected_passage_offsets.append(tf.gather_nd(psg_offsets, indices))
            selected_input_ids.append(tf.gather_nd(sequence_ids, indices))

        value_fn = partial(value_fn_template, pool_tensors=selected_input_ids)
        dist_selected_input_ids = strategy.experimental_distribute_values_from_function(
            value_fn)

        start_logits, end_logits = dist_forward_reader_step(
            input_ids=dist_selected_input_ids)
        sentence_ids = tf.concat(dist_selected_input_ids.values, axis=0)
        sentence_ids = tf.RaggedTensor.from_tensor(
            sentence_ids, padding=tokenizer.pad_token_id)
        sentence_ids = sentence_ids.to_list()
        sentence_ids = sentence_ids[:global_batch_size]
        selected_passage_offsets = tf.concat(selected_passage_offsets, axis=0)
        selected_passage_offsets = selected_passage_offsets[:global_batch_size]
        ctx_ids = [
            ids[offset:]
            for ids, offset in zip(sentence_ids, selected_passage_offsets)
        ]

        start_logits = tf.concat(start_logits.values, axis=0)
        start_logits = start_logits.numpy().tolist()
        start_logits = start_logits[:global_batch_size]
        start_logits = [
            logits[offset:offset + len(ctx)] for logits, offset, ctx in zip(
                start_logits, selected_passage_offsets, ctx_ids)
        ]
        end_logits = tf.concat(end_logits.values, axis=0)
        end_logits = end_logits.numpy().tolist()
        end_logits = end_logits[:global_batch_size]
        end_logits = [
            logits[offset:offset + len(ctx)] for logits, offset, ctx in zip(
                end_logits, selected_passage_offsets, ctx_ids)
        ]

        best_spans = processes.starmap(get_best_span_partial,
                                       zip(start_logits, end_logits, ctx_ids))

        answers_serialized = tf.concat(answers_serialized.values, axis=0)
        question = tf.concat(question.values, axis=0)

        answers = []
        for ans in answers_serialized:
            ans_sparse = tf.io.parse_tensor(ans, out_type=tf.string)
            ans_values = tf.io.parse_tensor(ans_sparse[1], out_type=tf.string)
            ans_values = [answer.numpy().decode() for answer in ans_values]
            answers.append(ans_values)

        question = question.numpy().tolist()
        question = [q.decode() for q in question]

        hits = processes.starmap(compare_spans, zip(answers, best_spans))
        passages = [tokenizer.decode(ids) for ids in ctx_ids]

        selected_passage_idxs = tf.concat(selected_passage_idxs, axis=0)
        selected_passage_idxs = selected_passage_idxs.numpy().tolist()

        stats = [{
            "question": q,
            "answers": ans,
            "passage": psg,
            "predicted": span,
            "retriever_rank": idx + 1,
            "hit": hit
        }
                 for q, ans, span, idx, psg, hit in zip(
                     question, answers, best_spans, selected_passage_idxs,
                     passages, hits)]
        match_stats.extend(stats)
        em_hits.extend(hits)

    print("done")
    print("-----------------------------------------------------------")
    return em_hits, match_stats

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--query-path",
                        type=str,
                        default=const.QUERY_PATH,
                        help="Path to the queries used to test retriever")
    parser.add_argument("--ctx-source-path",
                        type=str,
                        default=const.CTX_SOURCE_PATH,
                        help="Path to the file containg all passages")
    parser.add_argument("--checkpoint-path",
                        type=str,
                        default=const.CHECKPOINT_PATH,
                        help="Path to the checkpointed model")
    parser.add_argument(
        "--top-k",
        type=int,
        default=const.TOP_K,
        help="Number of documents that expects to be returned by retriever")
    parser.add_argument("--batch-size",
                        type=int,
                        default=16,
                        help="Batch size when embedding questions")
    parser.add_argument("--index-path",
                        type=str,
                        default=const.INDEX_PATH,
                        help="Path to indexed database")
    parser.add_argument("--pretrained-model",
                        type=str,
                        default=const.PRETRAINED_MODEL)
    parser.add_argument("--reader-data-path",
                        type=str,
                        default=const.READER_DATA_PATH)
    parser.add_argument("--result-path", type=str, default=const.RESULT_PATH)
    parser.add_argument("--embeddings-path",
                        type=str,
                        default=const.EMBEDDINGS_DIR)
    parser.add_argument("--force-create-index", type=eval, default=False)
    parser.add_argument("--qas-tfrecord-path",
                        type=str,
                        default=const.QAS_TFRECORD_PATH)
    parser.add_argument("--max-query-length",
                        type=int,
                        default=const.MAX_QUERY_LENGTH)
    parser.add_argument("--disable-tf-function", type=eval, default=False)
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--prefix", type=str, default='pretrained')

    global args
    args = parser.parse_args()
    model_type = os.path.basename(args.checkpoint_path)
    embeddings_path = os.path.join(args.embeddings_path, "shards-42031",
                                   model_type)
    args_dict = {**args.__dict__, "embeddings_path": embeddings_path}

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

    config_path = "configs/{}/{}/config.yml".format(
        __file__.rstrip(".py"),
        datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    write_config(config_path, args_dict)

    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # detect GPUs
        devices = tf.config.list_physical_devices("GPU")
        # [tf.config.experimental.set_memory_growth(device, True) for device in devices]
        if devices:
            strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        else:
            strategy = tf.distribute.get_strategy()

    index_path = os.path.join(args.index_path, model_type)
    if not os.path.exists(index_path):
        os.makedirs(index_path)
    indexer = create_or_retrieve_indexer(index_path=index_path,
                                         embeddings_path=embeddings_path)
    # exit(0) # only create index
    question_encoder = load_checkpoint(checkpoint_path=args.checkpoint_path,
                                       strategy=strategy)
    questions, answers = load_qas_test_data()
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)
    dataset = prepare_dataset(args.qas_tfrecord_path,
                              strategy=strategy,
                              tokenizer=tokenizer,
                              max_query_length=args.max_query_length)
    question_embeddings = generate_embeddings(
        question_encoder=question_encoder, dataset=dataset, strategy=strategy)
    top_ids_and_scores = search_knn(indexer=indexer,
                                    question_embeddings=question_embeddings)
    all_docs = load_ctx_sources()

    print("Validating... ")
    top_k_hits_path = os.path.join(args.result_path, model_type,
                                   "top_k_hits.txt")
    if not os.path.exists(os.path.dirname(top_k_hits_path)):
        os.makedirs(os.path.dirname(top_k_hits_path))

    start_time = time.perf_counter()
    questions_doc_hits = validate(top_ids_and_scores=top_ids_and_scores,
                                  answers=answers,
                                  ctx_sources=all_docs,
                                  top_k_hits_path=top_k_hits_path)
    print("done in {}s !".format(time.perf_counter() - start_time))
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

    print("Generating reader data... ")
    reader_data_path = os.path.join(args.reader_data_path, model_type,
                                    "reader_data.json")
    if not os.path.exists(os.path.dirname(reader_data_path)):
        os.makedirs(os.path.dirname(reader_data_path))
    start_time = time.perf_counter()
    save_results(questions=questions,
                 answers=answers,
                 all_docs=all_docs,
                 top_passages_and_scores=top_ids_and_scores,
                 per_question_hits=questions_doc_hits,
                 out_file=reader_data_path)

    print("done in {}s !".format(time.perf_counter() - start_time))
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--train-data-size", type=int, default=const.TRAIN_DATA_SIZE)
    parser.add_argument("--data-path", type=str, default=const.DATA_PATH, help="Path to the `.tfrecord` data. Data in this file is already preprocessed into tensor format")
    parser.add_argument("--max-context-length", type=int, default=const.MAX_CONTEXT_LENGTH, help="Maximum length of a document")
    parser.add_argument("--max-query-length", type=int, default=const.MAX_QUERY_LENGTH, help="Maximum length of a question")
    parser.add_argument("--batch-size", type=int, default=const.BATCH_SIZE, help="Batch size on each compute device")
    parser.add_argument("--epochs", type=int, default=const.EPOCHS)
    parser.add_argument("--learning-rate", type=float, default=const.LEARNING_RATE)
    parser.add_argument("--warmup-steps", type=int, default=const.WARMUP_STEPS)
    parser.add_argument("--adam-eps", type=float, default=const.ADAM_EPS)
    parser.add_argument("--adam-betas", type=eval, default=const.ADAM_BETAS)
    parser.add_argument("--weight-decay", type=float, default=const.WEIGHT_DECAY)
    parser.add_argument("--max-grad-norm", type=float, default=const.MAX_GRAD_NORM)
    parser.add_argument("--shuffle", type=eval, default=const.SHUFFLE)
    parser.add_argument("--seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--checkpoint-path", type=str, default=const.CHECKPOINT_PATH)
    parser.add_argument("--ctx-encoder-trainable", type=eval, default=const.CTX_ENCODER_TRAINABLE, help="Whether the context encoder's weights are trainable")
    parser.add_argument("--question-encoder-trainable", type=eval, default=const.QUESTION_ENCODER_TRAINABLE, help="Whether the question encoder's weights are trainable")
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--loss-fn", type=str, choices=['inbatch', 'threelevel', 'twolevel', 'hardnegvsneg', 'hardnegvsnegsoftmax', 'threelevelsoftmax'], default='threelevel')
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--load-optimizer", type=eval, default=True)
    parser.add_argument("--tokenizer", type=str, default="bert-base-uncased")
    parser.add_argument("--question-pretrained-model", type=str, default='bert-base-uncased')
    parser.add_argument("--context-pretrained-model", type=str, default='bigbird')
    parser.add_argument("--within-size", type=int, default=8)
    parser.add_argument("--prefix", type=str, default=None)

    global args
    args = parser.parse_args()
    args_dict = args.__dict__

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    file_name = os.path.basename(__file__)
    config_path = "configs/{}/{}/config.yml".format(file_name.rstrip(".py"), datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    epochs = args.epochs

    global strategy
    try: # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu) # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception: # detect GPUs
        devices = tf.config.list_physical_devices("GPU")
        # [tf.config.experimental.set_memory_growth(device, True) for device in devices]
        if devices:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()

    tf.random.set_seed(args.seed)

    tokenizer = get_tokenizer(model_name=args.tokenizer, prefix=args.prefix)

    """Data pipeline
    1. Load retriever data (in `.tfrecord` format, stored serialized `tf.int32` tensor)
    2. Padding sequence to the same length
    3. Shuffle: You should shuffle before batch to guarantee that each data sample can be batched with different data samples in different epochs
    4. Repeating data: repeat to produce indefininte data stream
    5. Batching dataset
    6. Prefetching dataset (to speed up training) 
    """
    dataset = biencoder_manipulator.load_retriever_tfrecord_int_data(
        input_path=args.data_path,
        shuffle=args.shuffle,
        shuffle_seed=args.seed
    )
    dataset = biencoder_manipulator.pad(
        dataset, 
        sep_token_id=tokenizer.sep_token_id,
        max_context_length=args.max_context_length, 
        max_query_length=args.max_query_length
    )
    if args.loss_fn == 'inbatch':
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:2]
            },
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    else:
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:args.within_size]
            }
        )
    dataset = dataset.shuffle(buffer_size=60000)
    dataset = dataset.repeat()
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    """
    Distribute the dataset
    """
    dist_dataset = strategy.distribute_datasets_from_function(
        lambda _: dataset
    )
    iterator = iter(dist_dataset)

    """
    Set up for distributed training
    """
    steps_per_epoch = args.train_data_size // (args.batch_size * strategy.num_replicas_in_sync)
    global optimizer
    global loss_fn
    global retriever
    with strategy.scope():
        # Instantiate question encoder
        question_encoder = get_encoder(
            model_name=args.question_pretrained_model,
            args=args,
            trainable=args.question_encoder_trainable,
            prefix=args.prefix
        )

        # Instantiate context encoder
        context_encoder = get_encoder(
            model_name=args.context_pretrained_model,
            args=args,
            trainable=args.ctx_encoder_trainable,
            prefix=args.prefix
        )

        retriever = models.BiEncoder(
            question_model=question_encoder,
            ctx_model=context_encoder,
            use_pooler=args.use_pooler
        )

        # Instantiate the optimizer
        optimizer = optimizers.get_adamw(
            steps_per_epoch=steps_per_epoch,
            warmup_steps=args.warmup_steps,
            epochs=args.epochs,
            learning_rate=args.learning_rate,
            eps=args.adam_eps,
            beta_1=args.adam_betas[0],
            beta_2=args.adam_betas[1],
            weight_decay=args.weight_decay,
        )

        # Define loss function
        if args.loss_fn == 'threelevel':
            loss_fn = biencoder.ThreeLevelDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'twolevel':
            loss_fn = biencoder.TwoLevelDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == "hardnegvsneg":
            loss_fn = biencoder.HardNegVsNegDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'hardnegvsnegsoftmax':
            loss_fn = biencoder.HardNegVsNegSoftMaxDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'threelevelsoftmax':
            loss_fn = biencoder.ThreeLevelSoftMaxDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        else:
            loss_fn = biencoder.InBatchDPRLoss(batch_size=args.batch_size)


    """
    Distributed train step
    """
    dist_train_step = get_dist_train_step(
        model_name=args.context_pretrained_model
    )


    """
    Configure checkpoint
    """
    with strategy.scope():
        checkpoint_path = args.checkpoint_path
        ckpt = tf.train.Checkpoint(
            model=retriever,
            current_epoch=tf.Variable(0)
        )
        if not args.load_optimizer:
            tmp_optimizer = copy.deepcopy(optimizer)
            ckpt.optimizer = tmp_optimizer
        else:
            ckpt.optimizer = optimizer

        ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

        # if a checkpoint exists, restore the latest checkpoint
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            current_epoch = ckpt.current_epoch.numpy()
            print("Latest checkpoint restored -- Model trained for {} epochs".format(current_epoch))
        else:
            print("Checkpoint not found. Train from scratch")
            current_epoch = 0
        
        if not args.load_optimizer:
            ckpt.optimizer = optimizer


    """
    Bootstrap
    """
    # Run one step up front so tf.function tracing and variable creation happen
    # outside the timed training loop.
    sample = next(iter(dist_dataset))
    dist_train_step(sample)


    """
    Training loop
    """
    for epoch in range(current_epoch, epochs):
        print("*************** Epoch {:02d}/{:02d} ***************".format(epoch + 1, epochs))
        begin_epoch_time = time.perf_counter()

        for step in range(steps_per_epoch):
            begin_step_time = time.perf_counter()
            loss = dist_train_step(next(iterator))
            print("Step {: <6d}Loss: {: <20f}Elapsed: {}".format(
                step + 1,
                loss.numpy(),
                time.perf_counter() - begin_step_time,
            ))

        print("\nEpoch's elapsed time: {}\n".format(time.perf_counter() - begin_epoch_time))

        ckpt.current_epoch.assign_add(1)
        
        # Checkpoint the model
        ckpt_save_path = ckpt_manager.save()
        print('\nSaving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))