def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--jsonl-path", type=str,
                        default="data/retriever/nq-train.jsonl")
    parser.add_argument("--tfrecord-text-path", type=str,
                        default="data/retriever/V3/N5000-TEXT")
    parser.add_argument("--tfrecord-int-path", type=str,
                        default="data/retriever/V3/N5000-INT")
    parser.add_argument("--ctx-source-path", type=str,
                        default="gs://openqa-dpr/data/wikipedia_split/shards-42031")
    parser.add_argument("--ctx-tokenized-path", type=str,
                        default="data/wikipedia_split/shards-42031-tfrecord")
    parser.add_argument("--shard-size", type=int, default=42031)
    parser.add_argument("--max-context-length", type=int,
                        default=const.MAX_CONTEXT_LENGTH)
    parser.add_argument("--max-query-length", type=int,
                        default=const.MAX_QUERY_LENGTH)
    parser.add_argument("--records-per-file", type=int, default=5000)
    parser.add_argument("--shuffle", type=eval, default=const.SHUFFLE)
    parser.add_argument("--shuffle-seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--pretrained-model", type=str,
                        default=const.PRETRAINED_MODEL)
    parser.add_argument("--qas-path", type=str, default=const.QUERY_PATH)
    parser.add_argument("--qas-tfrecord-path", type=str,
                        default="data/qas/nq-test.tfrecord")

    args = parser.parse_args()  # NOTE: the call below uses hard-coded paths and ignores these args

    # TODO: build tfrecord dataset
    # build_tfrecord_text_data_from_jsonl(
    #     input_path="data/retriever/V2/vicovid-train.jsonl",
    #     out_dir="data/retriever/V2/TEXT",
    #     records_per_file=5000,
    #     num_hard_negatives=1
    # )

    build_tfrecord_int_data_from_tfrecord_text_data(
        input_path='data/retriever/V2/TEXT',
        out_dir="data/retriever/V2/INT",
        tokenizer=get_tokenizer(model_name='NlpHUST/vibert4news-base-cased',
                                prefix='pretrained'),
        shuffle=True,
        shuffle_seed=123,
        num_hard_negatives=1)
def build_tfrecord_tokenized_data_for_qas_ver2(pretrained_model: str,
                                               qas_path: str,
                                               out_dir: str,
                                               prefix):
    """Tokenize tab-separated question/answer data and write the question
    token ids as serialized sparse tensors into a single TFRecord file."""
    tokenizer = get_tokenizer(model_name=pretrained_model, prefix=prefix)
    text_dataset = tf.data.TextLineDataset(qas_path)

    def _transform():
        for element in text_dataset:
            question, answers = element.numpy().decode().split("\t")
            question_ids = tokenizer.encode(question)
            question_ids = tf.convert_to_tensor(question_ids)
            yield tf.sparse.from_dense(question_ids)

    sparse_dataset = tf.data.Dataset.from_generator(
        _transform,
        output_signature=tf.SparseTensorSpec(shape=[None], dtype=tf.int32))

    def _serialize(element):
        element_serialized = tf.io.serialize_tensor(
            tf.io.serialize_sparse(element))
        features = {
            'question_serialized':
                tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[element_serialized.numpy()]))
        }
        example = tf.train.Example(features=tf.train.Features(
            feature=features))
        return example.SerializeToString()

    def _generator():
        count = 0
        for element in sparse_dataset:
            yield _serialize(element)
            count += 1
            print("Count: {}".format(count))

    dataset = tf.data.Dataset.from_generator(
        _generator, output_signature=tf.TensorSpec([], tf.string))

    # Make sure the output directory exists before writing
    # (mirrors the context-source builder below).
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    file_name = os.path.basename(qas_path).split(".")[0]
    writer = tf.data.experimental.TFRecordWriter(
        os.path.join(out_dir, "{}-ver2.tfrecord".format(file_name)))
    writer.write(dataset)
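# A minimal read-back sketch for the file written above, assuming the
# feature layout produced by build_tfrecord_tokenized_data_for_qas_ver2
# (the helper name is an assumption; relies on the module-level `tf` import).
def load_qas_ver2_tfrecord(tfrecord_path):
    feature_description = {
        'question_serialized': tf.io.FixedLenFeature([], tf.string)
    }

    def _parse(record):
        parsed = tf.io.parse_single_example(record, feature_description)
        # Undo tf.io.serialize_tensor: this recovers the 3-element output
        # of tf.io.serialize_sparse -> [indices, values, dense_shape].
        components = tf.io.parse_tensor(parsed['question_serialized'],
                                        out_type=tf.string)
        indices = tf.io.parse_tensor(components[0], out_type=tf.int64)
        values = tf.io.parse_tensor(components[1], out_type=tf.int32)
        dense_shape = tf.io.parse_tensor(components[2], out_type=tf.int64)
        # The questions are 1-D id sequences, so pin the component shapes.
        indices.set_shape([None, 1])
        values.set_shape([None])
        dense_shape.set_shape([1])
        return tf.sparse.to_dense(
            tf.SparseTensor(indices, values, dense_shape))

    return tf.data.TFRecordDataset(tfrecord_path).map(_parse)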
def validate(reader, strategy, dataset, params=None):
    print("Validating...")
    if params is not None:
        global args
        args = params

    def dist_forward_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)
            start_logits, end_logits = reader(input_ids=input_ids,
                                              attention_mask=attention_mask,
                                              training=False)
            return start_logits, end_logits

        per_replica_logits = strategy.run(step_fn, args=(input_ids,))
        return per_replica_logits

    if not args.disable_tf_function:
        dist_forward_step = tf.function(dist_forward_step)

    def value_fn_template(ctx, indices, tensors):
        start, end = indices[ctx.replica_id_in_sync_group]
        return tensors[start:end]

    processes = ProcessPool(processes=os.cpu_count())
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)
    get_best_span_partial = partial(get_best_span,
                                    max_answer_length=args.max_answer_length,
                                    tokenizer=tokenizer)

    iterator = iter(dataset)
    em_hits = []
    match_stats = []

    for element in tqdm(iterator):
        answers_serialized = element['answers']
        question = element['question']
        passage_offset = element['passage_offset']
        input_ids = element['input_ids']

        if strategy.num_replicas_in_sync > 1:
            reduced_input_ids = tf.concat(input_ids.values, axis=0)
        else:
            reduced_input_ids = input_ids
        global_batch_size = reduced_input_ids.shape[0]

        # forward pass
        if global_batch_size < args.batch_size * strategy.num_replicas_in_sync:
            base_replica_batch_size = args.batch_size
            flag = False
            while True:
                spread, global_batch_size, base_replica_batch_size = spread_samples_equally(
                    global_batch_size=global_batch_size,
                    num_replicas=strategy.num_replicas_in_sync,
                    base_replica_batch_size=base_replica_batch_size,
                    init_batch_size=args.batch_size)

                if len(spread) > 1:
                    indices = []
                    idx = 0
                    for num in spread:
                        indices.append((idx, idx + num))
                        idx += num
                    value_fn = partial(value_fn_template,
                                       indices=indices,
                                       tensors=reduced_input_ids)
                    reduced_input_ids = reduced_input_ids[
                        base_replica_batch_size * strategy.num_replicas_in_sync:]
                    dist_input_ids = strategy.experimental_distribute_values_from_function(value_fn)
                    start_logits, end_logits = dist_forward_step(dist_input_ids)
                    if not flag:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat(start_logits.values, axis=0)
                            global_end_logits = tf.concat(end_logits.values, axis=0)
                        else:
                            global_start_logits = start_logits
                            global_end_logits = end_logits
                        flag = True
                    else:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat(
                                [global_start_logits, *start_logits.values], axis=0)
                            global_end_logits = tf.concat(
                                [global_end_logits, *end_logits.values], axis=0)
                        else:
                            global_start_logits = tf.concat(
                                [global_start_logits, start_logits], axis=0)
                            global_end_logits = tf.concat(
                                [global_end_logits, end_logits], axis=0)
                    if global_batch_size == 0:
                        break
                else:
                    start_logits, end_logits = dist_forward_step(reduced_input_ids)
                    if not flag:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = start_logits.values[0]
                            global_end_logits = end_logits.values[0]
                        else:
                            global_start_logits = start_logits
                            global_end_logits = end_logits
                        flag = True
                    else:
                        if strategy.num_replicas_in_sync > 1:
                            global_start_logits = tf.concat(
                                [global_start_logits, start_logits.values[0]], axis=0)
                            global_end_logits = tf.concat(
                                [global_end_logits, end_logits.values[0]], axis=0)
                        else:
                            global_start_logits = tf.concat(
                                [global_start_logits, start_logits], axis=0)
                            global_end_logits = tf.concat(
                                [global_end_logits, end_logits], axis=0)
                    break
        else:
            start_logits, end_logits = dist_forward_step(input_ids)
            if strategy.num_replicas_in_sync > 1:
                global_start_logits = tf.concat(start_logits.values, axis=0)
                global_end_logits = tf.concat(end_logits.values, axis=0)
            else:
                global_start_logits = start_logits
                global_end_logits = end_logits

        if strategy.num_replicas_in_sync > 1:
            input_ids = tf.concat(input_ids.values, axis=0)
            passage_offset = tf.concat(passage_offset.values, axis=0)
            answers_serialized = tf.concat(answers_serialized.values, axis=0)
            question = tf.concat(question.values, axis=0)

        question = question.numpy().tolist()
        question = [q.decode() for q in question]

        sentence_ids = tf.RaggedTensor.from_tensor(input_ids,
                                                   padding=tokenizer.pad_token_id)
        sentence_ids = sentence_ids.to_list()
        ctx_ids = [ids[offset:] for ids, offset in zip(sentence_ids, passage_offset)]

        start_logits = global_start_logits.numpy().tolist()
        start_logits = [
            logits[offset:offset + len(ctx)]
            for logits, offset, ctx in zip(start_logits, passage_offset, ctx_ids)
        ]
        end_logits = global_end_logits.numpy().tolist()
        end_logits = [
            logits[offset:offset + len(ctx)]
            for logits, offset, ctx in zip(end_logits, passage_offset, ctx_ids)
        ]

        best_spans = processes.starmap(get_best_span_partial,
                                       zip(start_logits, end_logits, ctx_ids))

        answers = []
        for ans in answers_serialized:
            ans_sparse = tf.io.parse_tensor(ans, out_type=tf.string)
            ans_values = tf.io.parse_tensor(ans_sparse[1], out_type=tf.string)
            ans_values = [answer.numpy().decode() for answer in ans_values]
            answers.append(ans_values)

        hits = processes.starmap(compare_spans, zip(answers, best_spans))
        passages = [tokenizer.decode(ids) for ids in ctx_ids]
        stats = [{
            "question": q,
            "answers": ans,
            "passage": psg,
            "predicted": span,
            "hit": hit
        } for q, ans, span, psg, hit in zip(question, answers, best_spans,
                                            passages, hits)]

        match_stats.extend(stats)
        em_hits.extend(hits)

    print("done")
    print("-----------------------------------------------------------")
    return em_hits, match_stats
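# Hypothetical usage sketch: `reader`, `strategy`, and the reader dataset are
# assumed to be built elsewhere in this repo; exact match is simply the mean
# of the per-question hits returned by validate().
def report_exact_match(reader, strategy, reader_dataset, params):
    em_hits, match_stats = validate(reader=reader,
                                    strategy=strategy,
                                    dataset=reader_dataset,
                                    params=params)
    em = sum(em_hits) / max(len(em_hits), 1)
    print("Exact match: {:.4f} over {} questions".format(em, len(em_hits)))
    return em, match_stats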
# Set up a retrieval pipeline: deserialize the dense index from disk and
# build a retrieval function for the Vietnamese COVID-19 subset.
kwargs = {
    'ctx_source_path': 'data/wikipedia_split/vi_covid_subset_ctx_source.tsv',
    'pretrained_model': 'NlpHUST/vibert4news-base-cased',
    'use_pooler': False,
    'max_query_length': 64
}
args = Namespace(**kwargs)

index_path = "indexer/vicovid_inbatch_batch8_query64"
indexer = DenseFlatIndexer(buffer_size=50000)
index_files = glob.glob("{}/*".format(index_path))
print("Deserializing indexer from disk... ")
indexer.deserialize(index_path)

all_docs = load_ctx_sources(args.ctx_source_path)
tokenizer = get_tokenizer(model_name='NlpHUST/vibert4news-base-cased',
                          prefix='pretrained')
question_encoder = load_checkpoint(args)
retrieve_func = partial(retrieve,
                        indexer=indexer,
                        question_encoder=question_encoder,
                        tokenizer=tokenizer,
                        top_docs=10,
                        all_docs=all_docs)
print("done")
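# Hypothetical call: the remaining positional argument of `retrieve` is
# assumed to be the raw question string, and the return format is whatever
# `retrieve` produces (not fixed here).
results = retrieve_func("Các triệu chứng thường gặp của COVID-19 là gì?")
print(results)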
def build_tfrecord_tokenized_data_for_ctx_sources(
        pretrained_model: str,
        ctx_source_path: str,
        out_dir: str,
        max_context_length: int = 256,
        shard_size: int = 42031,
        prefix=None):
    tokenizer = get_tokenizer(model_name=pretrained_model, prefix=prefix)
    ctx_source_files = glob.glob("{}/*.tsv".format(ctx_source_path))
    ctx_source_files.sort()

    text_dataset = None
    for ctx_source in ctx_source_files:
        df = pd.read_csv(ctx_source, sep="\t", header=None,
                         names=['id', 'text', 'title'])
        if text_dataset is None:
            text_dataset = df
        else:
            text_dataset = pd.concat([text_dataset, df],
                                     axis='index', ignore_index=True)

    def _transform():
        count = 0
        for _, element in text_dataset.iterrows():
            doc_id = str(element.id)
            text = str(element.text)
            title = str(element.title)
            passage_id = "wiki:{}".format(doc_id)

            text_tokens = tokenizer.tokenize(text)
            title_tokens = tokenizer.tokenize(title)
            sent_tokens = [tokenizer.cls_token] + title_tokens \
                + [tokenizer.sep_token] + text_tokens + [tokenizer.sep_token]
            if len(sent_tokens) < max_context_length:
                sent_tokens += [tokenizer.pad_token] * (max_context_length - len(sent_tokens))
            sent_tokens = sent_tokens[:max_context_length]
            sent_tokens[-1] = tokenizer.sep_token
            context_ids = tokenizer.convert_tokens_to_ids(sent_tokens)
            context_ids = tf.convert_to_tensor(context_ids, dtype=tf.int32)

            count += 1
            print("Count: {}".format(count))
            yield {
                'context_ids': context_ids,
                'passage_id': tf.constant(passage_id)
            }

    dataset = tf.data.Dataset.from_generator(
        _transform,
        output_signature={
            'context_ids': tf.TensorSpec([max_context_length], tf.int32),
            'passage_id': tf.TensorSpec([], tf.string)
        })

    def _serialize(context_ids, passage_id):
        features = {
            'context_ids':
                tf.train.Feature(int64_list=tf.train.Int64List(value=context_ids)),
            'passage_id':
                tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[passage_id.numpy()]))
        }
        example = tf.train.Example(features=tf.train.Features(feature=features))
        return example.SerializeToString()

    dataset = dataset.map(
        lambda x: tf.py_function(_serialize,
                                 inp=[x['context_ids'], x['passage_id']],
                                 Tout=tf.string),
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=True)
    dataset = dataset.window(shard_size)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    idx = 0
    for window in dataset:
        writer = tf.data.experimental.TFRecordWriter(
            os.path.join(out_dir, "psgs_subset_{:02d}.tfrecord".format(idx)))
        writer.write(window)
        idx += 1
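# A minimal read-back sketch for the shards written above (the helper name is
# an assumption); `max_context_length` must match the value used at write time.
def load_ctx_tfrecord(tfrecord_dir, max_context_length=256):
    files = sorted(glob.glob("{}/*.tfrecord".format(tfrecord_dir)))
    feature_description = {
        'context_ids': tf.io.FixedLenFeature([max_context_length], tf.int64),
        'passage_id': tf.io.FixedLenFeature([], tf.string)
    }

    def _parse(record):
        parsed = tf.io.parse_single_example(record, feature_description)
        return {
            'context_ids': tf.cast(parsed['context_ids'], tf.int32),
            'passage_id': parsed['passage_id']
        }

    return tf.data.TFRecordDataset(files).map(
        _parse, num_parallel_calls=tf.data.AUTOTUNE)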
def validate(dataset, strategy, ranker, reader, params=None):
    if params is not None:
        global args
        args = params

    def dist_forward_ranker_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            input_ids = tf.reshape(input_ids, [-1, args.max_sequence_length])
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)
            rank_logits = ranker(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 training=False)
            return tf.reshape(rank_logits, [args.batch_size, -1])

        per_replica_logits = strategy.run(step_fn, args=(input_ids,))
        return per_replica_logits

    def dist_forward_reader_step(input_ids):
        print("This function is tracing")

        def step_fn(input_ids):
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)
            start_logits, end_logits = reader(input_ids=input_ids,
                                              attention_mask=attention_mask,
                                              training=False)
            return start_logits, end_logits

        per_replica_results = strategy.run(step_fn, args=(input_ids,))
        return per_replica_results

    if not args.disable_tf_function:
        dist_forward_ranker_step = tf.function(dist_forward_ranker_step)
        dist_forward_reader_step = tf.function(dist_forward_reader_step)

    def value_fn_template(ctx, pool_tensors):
        return pool_tensors[ctx.replica_id_in_sync_group]

    processes = ProcessPool(processes=os.cpu_count())
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)
    get_best_span_partial = partial(get_best_span,
                                    max_answer_length=args.max_answer_length,
                                    tokenizer=tokenizer)

    iterator = iter(dataset)
    em_hits = []
    match_stats = []

    for element in tqdm(iterator):
        answers_serialized = element['answers']
        question = element['question']
        # bsz x num_passages x max_sequence_length
        input_ids = element['passages/sequence_ids']
        # bsz x num_passages
        passage_offsets = element['passages/passage_offset']

        reduced_input_ids = tf.concat(input_ids.values, axis=0)
        per_replica_passage_offsets = strategy.experimental_local_results(
            passage_offsets)
        global_batch_size = reduced_input_ids.shape[0]

        if global_batch_size < args.batch_size * strategy.num_replicas_in_sync:
            # TODO: add code in case batch is not divisible
            # Pad the last, smaller batch up to the full global batch size.
            aggregated_input_ids = tf.concat(input_ids.values, axis=0)
            padded_size = args.batch_size * strategy.num_replicas_in_sync - global_batch_size
            padded_input_ids = tf.zeros(
                [padded_size, args.max_passages, args.max_sequence_length],
                dtype=tf.int32)
            input_ids = tf.concat([aggregated_input_ids, padded_input_ids],
                                  axis=0)
            pool_input_ids = tf.split(
                input_ids,
                num_or_size_splits=strategy.num_replicas_in_sync,
                axis=0)
            value_fn_for_input_ids = partial(value_fn_template,
                                             pool_tensors=pool_input_ids)
            input_ids = strategy.experimental_distribute_values_from_function(
                value_fn_for_input_ids)

            aggregated_per_replica_passage_offsets = tf.concat(
                per_replica_passage_offsets, axis=0)
            lack_size = args.batch_size * strategy.num_replicas_in_sync \
                - aggregated_per_replica_passage_offsets.shape[0]
            padded_per_replica_passage_offsets = tf.zeros(
                [lack_size, args.max_passages], dtype=tf.int32)
            per_replica_passage_offsets = tf.concat([
                aggregated_per_replica_passage_offsets,
                padded_per_replica_passage_offsets
            ], axis=0)
            per_replica_passage_offsets = tf.split(
                per_replica_passage_offsets,
                num_or_size_splits=strategy.num_replicas_in_sync)

        rank_logits = dist_forward_ranker_step(input_ids)
        rank_logits = strategy.experimental_local_results(rank_logits)
        # num_replicas x batch_size
        selected_passage_idxs = [
            tf.cast(tf.argmax(logits, axis=-1), dtype=tf.int32)
            for logits in rank_logits
        ]

        selected_passage_offsets = []
        # num_replicas x batch_size x max_passages x max_sequence_length
        per_replica_input_ids = strategy.experimental_local_results(input_ids)
        selected_input_ids = []
        for sequence_ids, psg_offsets, passage_idxs in zip(
                per_replica_input_ids, per_replica_passage_offsets,
                selected_passage_idxs):
            range_idxs = tf.range(sequence_ids.shape[0], dtype=tf.int32)
            indices = tf.concat([
                tf.expand_dims(range_idxs, axis=1),
                tf.expand_dims(passage_idxs, axis=1)
            ], axis=1)
            selected_passage_offsets.append(tf.gather_nd(psg_offsets, indices))
            selected_input_ids.append(tf.gather_nd(sequence_ids, indices))

        value_fn = partial(value_fn_template, pool_tensors=selected_input_ids)
        dist_selected_input_ids = strategy.experimental_distribute_values_from_function(
            value_fn)
        start_logits, end_logits = dist_forward_reader_step(
            input_ids=dist_selected_input_ids)

        sentence_ids = tf.concat(dist_selected_input_ids.values, axis=0)
        sentence_ids = tf.RaggedTensor.from_tensor(
            sentence_ids, padding=tokenizer.pad_token_id)
        sentence_ids = sentence_ids.to_list()
        sentence_ids = sentence_ids[:global_batch_size]

        selected_passage_offsets = tf.concat(selected_passage_offsets, axis=0)
        selected_passage_offsets = selected_passage_offsets[:global_batch_size]
        ctx_ids = [
            ids[offset:]
            for ids, offset in zip(sentence_ids, selected_passage_offsets)
        ]

        start_logits = tf.concat(start_logits.values, axis=0)
        start_logits = start_logits.numpy().tolist()
        start_logits = start_logits[:global_batch_size]
        start_logits = [
            logits[offset:offset + len(ctx)]
            for logits, offset, ctx in zip(start_logits,
                                           selected_passage_offsets, ctx_ids)
        ]
        end_logits = tf.concat(end_logits.values, axis=0)
        end_logits = end_logits.numpy().tolist()
        end_logits = end_logits[:global_batch_size]
        end_logits = [
            logits[offset:offset + len(ctx)]
            for logits, offset, ctx in zip(end_logits,
                                           selected_passage_offsets, ctx_ids)
        ]

        best_spans = processes.starmap(get_best_span_partial,
                                       zip(start_logits, end_logits, ctx_ids))

        answers_serialized = tf.concat(answers_serialized.values, axis=0)
        question = tf.concat(question.values, axis=0)

        answers = []
        for ans in answers_serialized:
            ans_sparse = tf.io.parse_tensor(ans, out_type=tf.string)
            ans_values = tf.io.parse_tensor(ans_sparse[1], out_type=tf.string)
            ans_values = [answer.numpy().decode() for answer in ans_values]
            answers.append(ans_values)

        question = question.numpy().tolist()
        question = [q.decode() for q in question]

        hits = processes.starmap(compare_spans, zip(answers, best_spans))
        passages = [tokenizer.decode(ids) for ids in ctx_ids]
        selected_passage_idxs = tf.concat(selected_passage_idxs, axis=0)
        selected_passage_idxs = selected_passage_idxs.numpy().tolist()
        stats = [{
            "question": q,
            "answers": ans,
            "passage": psg,
            "predicted": span,
            "retriever_rank": idx + 1,
            "hit": hit
        } for q, ans, span, idx, psg, hit in zip(question, answers, best_spans,
                                                 selected_passage_idxs,
                                                 passages, hits)]

        match_stats.extend(stats)
        em_hits.extend(hits)

    print("done")
    print("-----------------------------------------------------------")
    return em_hits, match_stats
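# Sketch of consuming the validation outputs; the `json` import, the output
# file name, and JSON-serializability of the predicted spans (hence
# `default=str`) are assumptions.
import json

def dump_match_stats(em_hits, match_stats, out_file="match_stats.json"):
    print("EM: {:.4f}".format(sum(em_hits) / max(len(em_hits), 1)))
    with open(out_file, "w") as f:
        json.dump(match_stats, f, ensure_ascii=False, indent=2, default=str)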
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--query-path", type=str, default=const.QUERY_PATH,
                        help="Path to the queries used to test the retriever")
    parser.add_argument("--ctx-source-path", type=str,
                        default=const.CTX_SOURCE_PATH,
                        help="Path to the file containing all passages")
    parser.add_argument("--checkpoint-path", type=str,
                        default=const.CHECKPOINT_PATH,
                        help="Path to the checkpointed model")
    parser.add_argument("--top-k", type=int, default=const.TOP_K,
                        help="Number of documents the retriever is expected to return")
    parser.add_argument("--batch-size", type=int, default=16,
                        help="Batch size when embedding questions")
    parser.add_argument("--index-path", type=str, default=const.INDEX_PATH,
                        help="Path to the indexed database")
    parser.add_argument("--pretrained-model", type=str,
                        default=const.PRETRAINED_MODEL)
    parser.add_argument("--reader-data-path", type=str,
                        default=const.READER_DATA_PATH)
    parser.add_argument("--result-path", type=str, default=const.RESULT_PATH)
    parser.add_argument("--embeddings-path", type=str,
                        default=const.EMBEDDINGS_DIR)
    parser.add_argument("--force-create-index", type=eval, default=False)
    parser.add_argument("--qas-tfrecord-path", type=str,
                        default=const.QAS_TFRECORD_PATH)
    parser.add_argument("--max-query-length", type=int,
                        default=const.MAX_QUERY_LENGTH)
    parser.add_argument("--disable-tf-function", type=eval, default=False)
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--prefix", type=str, default='pretrained')

    global args
    args = parser.parse_args()

    model_type = os.path.basename(args.checkpoint_path)
    embeddings_path = os.path.join(args.embeddings_path, "shards-42031", model_type)
    args_dict = {**args.__dict__, "embeddings_path": embeddings_path}

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    config_path = "configs/{}/{}/config.yml".format(
        os.path.splitext(os.path.basename(__file__))[0],
        datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    write_config(config_path, args_dict)

    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # detect GPUs
        devices = tf.config.list_physical_devices("GPU")
        # [tf.config.experimental.set_memory_growth(device, True) for device in devices]
        if devices:
            strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        else:
            strategy = tf.distribute.get_strategy()

    index_path = os.path.join(args.index_path, model_type)
    if not os.path.exists(index_path):
        os.makedirs(index_path)
    indexer = create_or_retrieve_indexer(index_path=index_path,
                                         embeddings_path=embeddings_path)
    # exit(0)  # only create index

    question_encoder = load_checkpoint(checkpoint_path=args.checkpoint_path,
                                       strategy=strategy)
    questions, answers = load_qas_test_data()
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)
    dataset = prepare_dataset(args.qas_tfrecord_path,
                              strategy=strategy,
                              tokenizer=tokenizer,
                              max_query_length=args.max_query_length)
    question_embeddings = generate_embeddings(question_encoder=question_encoder,
                                              dataset=dataset,
                                              strategy=strategy)
    top_ids_and_scores = search_knn(indexer=indexer,
                                    question_embeddings=question_embeddings)
    all_docs = load_ctx_sources()

    print("Validating... ")
    top_k_hits_path = os.path.join(args.result_path, model_type, "top_k_hits.txt")
    if not os.path.exists(os.path.dirname(top_k_hits_path)):
        os.makedirs(os.path.dirname(top_k_hits_path))
    start_time = time.perf_counter()
    questions_doc_hits = validate(top_ids_and_scores=top_ids_and_scores,
                                  answers=answers,
                                  ctx_sources=all_docs,
                                  top_k_hits_path=top_k_hits_path)
    print("done in {}s !".format(time.perf_counter() - start_time))
    print("----------------------------------------------------------------------------------------------------------------------")

    print("Generating reader data... ")
    reader_data_path = os.path.join(args.reader_data_path, model_type,
                                    "reader_data.json")
    if not os.path.exists(os.path.dirname(reader_data_path)):
        os.makedirs(os.path.dirname(reader_data_path))
    start_time = time.perf_counter()
    save_results(questions=questions,
                 answers=answers,
                 all_docs=all_docs,
                 top_passages_and_scores=top_ids_and_scores,
                 per_question_hits=questions_doc_hits,
                 out_file=reader_data_path)
    print("done in {}s !".format(time.perf_counter() - start_time))
    print("----------------------------------------------------------------------------------------------------------------------")
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-data-size", type=int,
                        default=const.TRAIN_DATA_SIZE)
    parser.add_argument("--data-path", type=str, default=const.DATA_PATH,
                        help="Path to the `.tfrecord` data. Data in this file is already preprocessed into tensor format")
    parser.add_argument("--max-context-length", type=int,
                        default=const.MAX_CONTEXT_LENGTH,
                        help="Maximum length of a document")
    parser.add_argument("--max-query-length", type=int,
                        default=const.MAX_QUERY_LENGTH,
                        help="Maximum length of a question")
    parser.add_argument("--batch-size", type=int, default=const.BATCH_SIZE,
                        help="Batch size on each compute device")
    parser.add_argument("--epochs", type=int, default=const.EPOCHS)
    parser.add_argument("--learning-rate", type=float,
                        default=const.LEARNING_RATE)
    parser.add_argument("--warmup-steps", type=int, default=const.WARMUP_STEPS)
    parser.add_argument("--adam-eps", type=float, default=const.ADAM_EPS)
    parser.add_argument("--adam-betas", type=eval, default=const.ADAM_BETAS)
    parser.add_argument("--weight-decay", type=float,
                        default=const.WEIGHT_DECAY)
    parser.add_argument("--max-grad-norm", type=float,
                        default=const.MAX_GRAD_NORM)
    parser.add_argument("--shuffle", type=eval, default=const.SHUFFLE)
    parser.add_argument("--seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--checkpoint-path", type=str,
                        default=const.CHECKPOINT_PATH)
    parser.add_argument("--ctx-encoder-trainable", type=eval,
                        default=const.CTX_ENCODER_TRAINABLE,
                        help="Whether the context encoder's weights are trainable")
    parser.add_argument("--question-encoder-trainable", type=eval,
                        default=const.QUESTION_ENCODER_TRAINABLE,
                        help="Whether the question encoder's weights are trainable")
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--loss-fn", type=str,
                        choices=['inbatch', 'threelevel', 'twolevel',
                                 'hardnegvsneg', 'hardnegvsnegsoftmax',
                                 'threelevelsoftmax'],
                        default='threelevel')
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--load-optimizer", type=eval, default=True)
    parser.add_argument("--tokenizer", type=str, default="bert-base-uncased")
    parser.add_argument("--question-pretrained-model", type=str,
                        default='bert-base-uncased')
    parser.add_argument("--context-pretrained-model", type=str,
                        default='bigbird')
    parser.add_argument("--within-size", type=int, default=8)
    parser.add_argument("--prefix", type=str, default=None)

    global args
    args = parser.parse_args()
    args_dict = args.__dict__

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    file_name = os.path.basename(__file__)
    config_path = "configs/{}/{}/config.yml".format(
        os.path.splitext(file_name)[0],
        datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    epochs = args.epochs

    global strategy
    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # detect GPUs
        devices = tf.config.list_physical_devices("GPU")
        # [tf.config.experimental.set_memory_growth(device, True) for device in devices]
        if devices:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()

    tf.random.set_seed(args.seed)
    tokenizer = get_tokenizer(model_name=args.tokenizer, prefix=args.prefix)

    """Data pipeline
    1. Load retriever data (in `.tfrecord` format, stored as serialized `tf.int32` tensors)
    2. Pad sequences to the same length
    3. Shuffle: shuffle before batching to guarantee that each sample can be batched with different samples in different epochs
    4. Repeat: repeat to produce an indefinite data stream
    5. Batch the dataset
    6. Prefetch the dataset (to speed up training)
    """
    dataset = biencoder_manipulator.load_retriever_tfrecord_int_data(
        input_path=args.data_path,
        shuffle=args.shuffle,
        shuffle_seed=args.seed
    )
    dataset = biencoder_manipulator.pad(
        dataset,
        sep_token_id=tokenizer.sep_token_id,
        max_context_length=args.max_context_length,
        max_query_length=args.max_query_length
    )
    if args.loss_fn == 'inbatch':
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:2]
            },
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    else:
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:args.within_size]
            }
        )
    dataset = dataset.shuffle(buffer_size=60000)
    dataset = dataset.repeat()
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    """Distribute the dataset"""
    dist_dataset = strategy.distribute_datasets_from_function(lambda _: dataset)
    iterator = iter(dist_dataset)

    """Set up for distributed training"""
    steps_per_epoch = args.train_data_size // (args.batch_size * strategy.num_replicas_in_sync)

    global optimizer
    global loss_fn
    global retriever

    with strategy.scope():
        # Instantiate the question encoder
        question_encoder = get_encoder(
            model_name=args.question_pretrained_model,
            args=args,
            trainable=args.question_encoder_trainable,
            prefix=args.prefix
        )
        # Instantiate the context encoder
        context_encoder = get_encoder(
            model_name=args.context_pretrained_model,
            args=args,
            trainable=args.ctx_encoder_trainable,
            prefix=args.prefix
        )
        retriever = models.BiEncoder(
            question_model=question_encoder,
            ctx_model=context_encoder,
            use_pooler=args.use_pooler
        )
        # Instantiate the optimizer
        optimizer = optimizers.get_adamw(
            steps_per_epoch=steps_per_epoch,
            warmup_steps=args.warmup_steps,
            epochs=args.epochs,
            learning_rate=args.learning_rate,
            eps=args.adam_eps,
            beta_1=args.adam_betas[0],
            beta_2=args.adam_betas[1],
            weight_decay=args.weight_decay,
        )
        # Define the loss function
        if args.loss_fn == 'threelevel':
            loss_fn = biencoder.ThreeLevelDPRLoss(batch_size=args.batch_size,
                                                  within_size=args.within_size)
        elif args.loss_fn == 'twolevel':
            loss_fn = biencoder.TwoLevelDPRLoss(batch_size=args.batch_size,
                                                within_size=args.within_size)
        elif args.loss_fn == 'hardnegvsneg':
            loss_fn = biencoder.HardNegVsNegDPRLoss(batch_size=args.batch_size,
                                                    within_size=args.within_size)
        elif args.loss_fn == 'hardnegvsnegsoftmax':
            loss_fn = biencoder.HardNegVsNegSoftMaxDPRLoss(batch_size=args.batch_size,
                                                           within_size=args.within_size)
        elif args.loss_fn == 'threelevelsoftmax':
            loss_fn = biencoder.ThreeLevelSoftMaxDPRLoss(batch_size=args.batch_size,
                                                         within_size=args.within_size)
        else:
            loss_fn = biencoder.InBatchDPRLoss(batch_size=args.batch_size)

    """Distributed train step"""
    dist_train_step = get_dist_train_step(
        model_name=args.context_pretrained_model)

    """Configure checkpoint"""
    with strategy.scope():
        checkpoint_path = args.checkpoint_path
        ckpt = tf.train.Checkpoint(
            model=retriever,
            current_epoch=tf.Variable(0)
        )
        if not args.load_optimizer:
            # Attach a throwaway copy so restoring does not touch the real optimizer.
            tmp_optimizer = copy.deepcopy(optimizer)
            ckpt.optimizer = tmp_optimizer
        else:
            ckpt.optimizer = optimizer

        ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path,
                                                  max_to_keep=3)

        # If a checkpoint exists, restore the latest checkpoint
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            current_epoch = ckpt.current_epoch.numpy()
            print("Latest checkpoint restored -- Model trained for {} epochs".format(current_epoch))
        else:
            print("Checkpoint not found. Train from scratch")
            current_epoch = 0

        if not args.load_optimizer:
            ckpt.optimizer = optimizer

    """Bootstrap"""
    sample = next(iter(dist_dataset))
    dist_train_step(sample)

    """Training loop"""
    for epoch in range(current_epoch, epochs):
        print("*************** Epoch {:02d}/{:02d} ***************".format(epoch + 1, epochs))
        begin_epoch_time = time.perf_counter()

        for step in range(steps_per_epoch):
            begin_step_time = time.perf_counter()
            loss = dist_train_step(next(iterator))
            print("Step {: <6d}Loss: {: <20f}Elapsed: {}".format(
                step + 1,
                loss.numpy(),
                time.perf_counter() - begin_step_time,
            ))

        print("\nEpoch's elapsed time: {}\n".format(time.perf_counter() - begin_epoch_time))

        ckpt.current_epoch.assign_add(1)

        # Checkpoint the model
        ckpt_save_path = ckpt_manager.save()
        print('\nSaving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
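# Hypothetical sketch of what a train step returned by get_dist_train_step
# could look like; the BiEncoder call signature and the loss interface below
# are assumptions, not this repo's actual implementation.
def example_get_dist_train_step():
    @tf.function
    def dist_train_step(element):
        def step_fn(element):
            with tf.GradientTape() as tape:
                q_repr, ctx_repr = retriever(
                    question_ids=element['question'],   # assumed input name
                    context_ids=element['contexts'],    # assumed input name
                    training=True)
                loss = loss_fn(q_repr, ctx_repr)        # assumed interface
            grads = tape.gradient(loss, retriever.trainable_variables)
            optimizer.apply_gradients(zip(grads, retriever.trainable_variables))
            return loss

        # Run one step on each replica, then average the per-replica losses.
        per_replica_losses = strategy.run(step_fn, args=(element,))
        return strategy.reduce(tf.distribute.ReduceOp.MEAN,
                               per_replica_losses, axis=None)

    return dist_train_step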