def main():
    """Evaluate a reader checkpoint and save the exact-match results.

    Parses the command-line options into the module-level ``args``, prints
    and persists the effective configuration, selects a distribution
    strategy (TPU first, then GPU, then the default strategy) and runs
    ``load_dataset`` -> ``load_checkpoint`` -> ``validate`` ->
    ``save_results``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--data-path", type=str, default="gs://openqa-dpr/data/reader/nq/dev/dev.tfrecord")
    parser.add_argument("--max-sequence-length", type=int, default=256)
    parser.add_argument("--tpu", type=str, default="tpu-v3")
    # Registered exactly once: argparse raises ArgumentError when the same
    # option string is added twice.
    parser.add_argument("--pretrained-model", type=str, default="bert-base-uncased")
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--checkpoint-path", type=str, default="gs://openqa-dpr/checkpoints/reader/baseline")
    # NOTE(review): type=eval evaluates the raw command-line text as Python;
    # only pass trusted values (e.g. True / False).
    parser.add_argument("--disable-tf-function", type=eval, default=False)
    parser.add_argument("--max-answer-length", type=int, default=10)
    parser.add_argument("--res-dir", type=str, default="results/reader")
    parser.add_argument("--prefix", type=str, default='pretrained')

    global args
    args = parser.parse_args()

    # Results are grouped per checkpoint; vars(args) is the namespace's own
    # dict, so the updated res_dir is visible through both args and args_dict.
    checkpoint_type = os.path.basename(args.checkpoint_path)
    args.res_dir = os.path.join(args.res_dir, checkpoint_type)
    args_dict = vars(args)

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    # splitext removes only the extension; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters from the script name.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    config_path = "configs/{}/{}/config.yml".format(script_name, datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # any TPU set-up failure: fall back to GPUs, then to the default strategy
        devices = tf.config.list_physical_devices("GPU")
        if devices:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()

    dataset = load_dataset(data_path=args.data_path, strategy=strategy)
    reader = load_checkpoint(
        pretrained_model=args.pretrained_model,
        checkpoint_path=args.checkpoint_path,
        strategy=strategy
    )

    em_hits, match_stats = validate(
        reader=reader,
        strategy=strategy,
        dataset=dataset
    )

    save_results(em_hits=em_hits, match_stats=match_stats, out_dir=args.res_dir)
def main():
    """Encode every context shard with the checkpointed context encoder.

    Parses the command-line options into the module-level ``args``, prints
    and persists the effective configuration, selects a distribution
    strategy (TPU first, then GPU, then the default strategy), builds the
    batched context dataset, restores the context encoder from the latest
    checkpoint and hands everything to ``run`` to write the embeddings.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--checkpoint-path", type=str, default=const.CHECKPOINT_PATH)
    parser.add_argument("--ctx-source-shards-tfrecord", type=str, default=const.CTX_SOURCE_SHARDS_TFRECORD)
    parser.add_argument("--records-per-file", type=int, default=const.RECORDS_PER_FILE)
    parser.add_argument("--embeddings-path", type=str, default=const.EMBEDDINGS_DIR)
    parser.add_argument("--seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--batch-size", type=int, default=const.EVAL_BATCH_SIZE)
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--max-context-length", type=int, default=const.MAX_CONTEXT_LENGTH, help="Maximum length of a document")
    parser.add_argument("--pretrained-model", type=str, default=const.PRETRAINED_MODEL)
    # NOTE(review): type=eval evaluates the raw command-line text as Python;
    # only pass trusted values (e.g. True / False).
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--disable-tf-function", type=eval, default=False)
    parser.add_argument("--prefix", type=str, default='pretrained')

    global args
    args = parser.parse_args()

    # Embeddings are grouped per model; the resolved directory is reported in
    # the configuration but args.embeddings_path itself is left untouched.
    model_type = os.path.basename(args.checkpoint_path)
    embeddings_path = os.path.join(args.embeddings_path, "shards-42031", model_type)
    args_dict = {**args.__dict__, "embeddings_path": embeddings_path}

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    # Use only the script's base name without its extension: the raw __file__
    # may contain directories, and str.rstrip(".py") strips trailing '.', 'p'
    # and 'y' characters rather than the suffix. The timestamp avoids ':' so
    # the directory name is valid on every filesystem.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    config_path = "configs/{}/{}/config.yml".format(script_name, datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # any TPU set-up failure: fall back to GPUs, then to the default strategy
        devices = tf.config.list_physical_devices("GPU")
        if devices:
            # Embedding generation deliberately runs on the first GPU only.
            strategy = tf.distribute.MirroredStrategy(["GPU:0"])
        else:
            strategy = tf.distribute.get_strategy()

    tf.random.set_seed(args.seed)

    # ---- Data pipeline ----
    print("Data pipeline processing...")
    dataset = biencoder_manipulator.load_tfrecord_tokenized_data_for_ctx_sources(
        input_path=args.ctx_source_shards_tfrecord,
        max_context_length=args.max_context_length
    )
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    # Distribute the dataset
    dist_dataset = strategy.distribute_datasets_from_function(
        lambda _: dataset
    )

    print("done")
    print("----------------------------------------------------------------------------------------------------------------------")

    # ---- Load checkpoint ----
    print("Loading checkpoint...")
    # Resolve the latest checkpoint once and reuse it for restore and logging.
    latest_checkpoint = tf.train.latest_checkpoint(args.checkpoint_path)
    with strategy.scope():
        context_encoder = get_encoder(
            model_name=args.pretrained_model,
            args=args,
            trainable=False,
            prefix=args.prefix
        )

        # Mirror the object layout used when saving (root.model.ctx_model) so
        # only the context encoder's weights are matched and restored.
        retriever = tf.train.Checkpoint(ctx_model=context_encoder)
        root_ckpt = tf.train.Checkpoint(model=retriever)

        root_ckpt.restore(latest_checkpoint).expect_partial()

    print(latest_checkpoint)

    print("done")
    print("----------------------------------------------------------------------------------------------------------------------")

    # ---- Generate embeddings ----
    print("Generate embeddings...")
    out_dir = embeddings_path
    os.makedirs(out_dir, exist_ok=True)

    run(
        dataset=dist_dataset,
        strategy=strategy,
        context_encoder=context_encoder,
        out_dir=out_dir
    )

    print("done")
    print("----------------------------------------------------------------------------------------------------------------------")
Ejemplo n.º 3
0
    w_min_df = 3
    w_max_df = 0.9
    w_strip_accents = 'unicode'
    w_use_idf = 1
    w_smooth_idf = 1
    w_sublinear_tf = 1
    c_sublinear_tf = True
    c_lowercase = False
    c_strip_accents = 'unicode'
    c_analyzer = 'char'
    c_ngram_range = (5, 6)
    c_max_features = 50000


# Instantiate the run configuration and persist it.
cfg = Config()
# NOTE(review): the class, not the `cfg` instance, is passed here — confirm
# this is what write_config expects.
write_config(Config)

# Create the per-model output directory; exist_ok avoids the
# check-then-create race and missing parents are created as well.
os.makedirs(cfg.root + cfg.model_name, exist_ok=True)

train = pd.read_csv(cfg.train_fn, index_col=0)
test = pd.read_csv(TEST_FILENAME, index_col=0)
subm = pd.read_csv(SAMPLE_SUBMISSION_FILENAME)

# 'none' is 1 for rows where no class in LIST_CLASSES is set, 0 otherwise.
train['none'] = 1 - train[LIST_CLASSES].max(axis=1)

# Assign the filled column back instead of calling fillna(inplace=True) on a
# column selection: the chained in-place form is deprecated in pandas 2.x and
# does not update the frame under copy-on-write.
train[COMMENT] = train[COMMENT].fillna(NAN_WORD)
test[COMMENT] = test[COMMENT].fillna(NAN_WORD)

if cfg.do_preprocess:
    train = preprocess(train)
Ejemplo n.º 4
0
def main():
    """Train the reader model with a custom distributed training loop.

    Parses the command-line options, prints and persists the effective
    configuration, selects a distribution strategy (TPU first, then GPU,
    then the default strategy), builds the training dataset, creates the
    reader, optimizer and loss calculator under the strategy scope, restores
    the latest checkpoint if one exists, and then trains for the remaining
    epochs, saving a checkpoint after each one.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train-data-size", type=int, default=263011)
    parser.add_argument("--data-path",
                        type=str,
                        default="gs://openqa-dpr/data/reader/nq/train")
    parser.add_argument("--max-sequence-length", type=int, default=256)
    parser.add_argument("--max-answers", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--max-grad-norm", type=float, default=2.0)
    parser.add_argument("--warmup-steps", type=int, default=0)
    parser.add_argument("--weight-decay", type=float, default=0.0)
    parser.add_argument("--adam-eps", type=float, default=1e-8)
    # NOTE(review): type=eval evaluates the raw command-line text as Python;
    # only pass trusted values (e.g. "(0.9, 0.999)", True / False).
    parser.add_argument("--adam-betas", type=eval, default=(0.9, 0.999))
    parser.add_argument("--shuffle", type=eval, default=True)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--checkpoint-path",
                        type=str,
                        default="gs://openqa-dpr/checkpoints/reader/baseline")
    parser.add_argument("--tpu", type=str, default="tpu-v3")
    parser.add_argument("--pretrained-model",
                        type=str,
                        default="bert-base-uncased")
    parser.add_argument("--load-optimizer", type=eval, default=True)
    parser.add_argument("--max-to-keep", type=int, default=50)
    parser.add_argument("--prefix", type=str, default='pretrained')

    args = parser.parse_args()
    args_dict = args.__dict__

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

    # splitext removes only the extension; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters from the script name.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    config_path = "configs/{}/{}/config.yml".format(
        script_name,
        datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    # ---- Set up devices ----
    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # any TPU set-up failure: fall back to GPUs, then to the default strategy
        devices = tf.config.list_physical_devices("GPU")
        if devices:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()

    tf.random.set_seed(args.seed)

    # ---- Data pipeline ----
    dataset = reader_manipulator.load_tfrecord_reader_train_data(
        input_path=args.data_path)

    # Honour the command-line limits; their defaults (256 / 10) equal the
    # values that used to be hard-coded here.
    dataset = reader_manipulator.transform_to_reader_train_dataset(
        dataset=dataset,
        max_sequence_length=args.max_sequence_length,
        max_answers=args.max_answers)
    dataset = dataset.cache()
    dataset = dataset.shuffle(buffer_size=70000)
    dataset = dataset.repeat()
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # ---- Distribute the dataset ----
    dist_dataset = strategy.distribute_datasets_from_function(
        lambda _: dataset)
    # Iterate the *distributed* dataset so every replica receives its own
    # batch in strategy.run; iterating the plain dataset bypasses the strategy.
    iterator = iter(dist_dataset)

    # ---- Set up for distributed training ----
    steps_per_epoch = args.train_data_size // (args.batch_size *
                                               strategy.num_replicas_in_sync)
    config = get_config(model_name=args.pretrained_model, prefix=args.prefix)
    with strategy.scope():
        encoder = get_encoder(model_name=args.pretrained_model,
                              args=args,
                              trainable=True,
                              prefix=args.prefix)
        # The pooler is frozen, so it contributes no trainable weights.
        encoder.bert.pooler.trainable = False

        reader = models.Reader(
            encoder=encoder,
            initializer_range=config.initializer_range,
        )

        optimizer = optimizers.get_adamw(
            epochs=args.epochs,
            steps_per_epoch=steps_per_epoch,
            warmup_steps=args.warmup_steps,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            eps=args.adam_eps,
            beta_1=args.adam_betas[0],
            beta_2=args.adam_betas[1],
        )

        loss_calculator = ReaderLossCalculator()

    # ---- Distributed train step ----
    @tf.function
    def dist_train_step(element):
        """Run one training step on every replica and return the summed loss."""
        # Executes only while tf.function traces, so it reveals retracing.
        print("This function is tracing !")

        def step_fn(element):
            """The computation to be run on each compute device"""
            input_ids = element['input_ids']
            # Positive token ids are attended to; id 0 is treated as padding.
            attention_mask = tf.cast(input_ids > 0, dtype=tf.int32)
            start_positions = element['start_positions']
            end_positions = element['end_positions']
            # Only answers with a positive start position count in the loss.
            answer_mask = tf.cast(start_positions > 0, dtype=tf.float32)

            with tf.GradientTape() as tape:
                start_logits, end_logits = reader(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    training=True)

                loss = loss_calculator.compute_token_loss(
                    start_logits=start_logits,
                    end_logits=end_logits,
                    start_positions=start_positions,
                    end_positions=end_positions,
                    answer_mask=answer_mask)

                # Scale by the global batch size so that summing the
                # per-replica losses yields the global mean.
                loss = tf.nn.compute_average_loss(
                    loss,
                    global_batch_size=args.batch_size *
                    strategy.num_replicas_in_sync)

            grads = tape.gradient(loss, reader.trainable_weights)
            # Each gradient tensor is clipped by its own norm.
            grads = [tf.clip_by_norm(g, args.max_grad_norm) for g in grads]
            optimizer.apply_gradients(zip(grads, reader.trainable_weights))

            return loss

        per_replica_losses = strategy.run(step_fn, args=(element, ))
        return strategy.reduce(tf.distribute.ReduceOp.SUM,
                               per_replica_losses,
                               axis=None)

    # ---- Configure checkpoint ----
    with strategy.scope():
        checkpoint_path = args.checkpoint_path
        ckpt = tf.train.Checkpoint(model=reader, current_epoch=tf.Variable(0))
        if not args.load_optimizer:
            # Restore into a throw-away copy so the real optimizer keeps its
            # fresh state; the real one is re-attached after the restore.
            tmp_optimizer = copy.deepcopy(optimizer)
            ckpt.optimizer = tmp_optimizer
        else:
            ckpt.optimizer = optimizer

        ckpt_manager = tf.train.CheckpointManager(ckpt,
                                                  checkpoint_path,
                                                  max_to_keep=args.max_to_keep)

        # if a checkpoint exists, restore the latest checkpoint
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            current_epoch = ckpt.current_epoch.numpy()
            print("Latest checkpoint restored -- Model trained for {} epochs".
                  format(current_epoch))
        else:
            print("Checkpoint not found. Train from scratch")
            current_epoch = 0

        if not args.load_optimizer:
            ckpt.optimizer = optimizer

    # ---- Bootstrap ----
    # Run one step before the timed loop; this first call traces the
    # tf.function.
    sample = next(iter(dist_dataset))
    dist_train_step(sample)

    # ---- Training loop ----
    for epoch in range(current_epoch, args.epochs):
        print("*************** Epoch {:02d}/{:02d} ***************".format(
            epoch + 1, args.epochs))
        begin_epoch_time = time.perf_counter()

        for step in range(steps_per_epoch):
            begin_step_time = time.perf_counter()
            loss = dist_train_step(next(iterator))
            print("Step {: <6d}Loss: {: <20f}Elapsed: {}".format(
                step + 1,
                loss.numpy(),
                time.perf_counter() - begin_step_time,
            ))

        print("\nEpoch's elapsed time: {}\n".format(time.perf_counter() -
                                                    begin_epoch_time))

        ckpt.current_epoch.assign_add(1)

        # Checkpoint the model
        ckpt_save_path = ckpt_manager.save()
        print('\nSaving checkpoint for epoch {} at {}'.format(
            epoch + 1, ckpt_save_path))
def main():
    """Evaluate the retriever and generate the reader's input data.

    Parses the command-line options into the module-level ``args``, prints
    and persists the effective configuration, selects a distribution
    strategy, builds or loads the passage index, embeds the test questions
    with the checkpointed question encoder, searches the index, validates
    the retrieved passages against the answers (writing the top-k hits file)
    and finally saves the reader data file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--query-path",
                        type=str,
                        default=const.QUERY_PATH,
                        help="Path to the queries used to test retriever")
    parser.add_argument("--ctx-source-path",
                        type=str,
                        default=const.CTX_SOURCE_PATH,
                        help="Path to the file containg all passages")
    parser.add_argument("--checkpoint-path",
                        type=str,
                        default=const.CHECKPOINT_PATH,
                        help="Path to the checkpointed model")
    parser.add_argument(
        "--top-k",
        type=int,
        default=const.TOP_K,
        help="Number of documents that expects to be returned by retriever")
    parser.add_argument("--batch-size",
                        type=int,
                        default=16,
                        help="Batch size when embedding questions")
    parser.add_argument("--index-path",
                        type=str,
                        default=const.INDEX_PATH,
                        help="Path to indexed database")
    parser.add_argument("--pretrained-model",
                        type=str,
                        default=const.PRETRAINED_MODEL)
    parser.add_argument("--reader-data-path",
                        type=str,
                        default=const.READER_DATA_PATH)
    parser.add_argument("--result-path", type=str, default=const.RESULT_PATH)
    parser.add_argument("--embeddings-path",
                        type=str,
                        default=const.EMBEDDINGS_DIR)
    # NOTE(review): type=eval evaluates the raw command-line text as Python;
    # only pass trusted values (e.g. True / False).
    parser.add_argument("--force-create-index", type=eval, default=False)
    parser.add_argument("--qas-tfrecord-path",
                        type=str,
                        default=const.QAS_TFRECORD_PATH)
    parser.add_argument("--max-query-length",
                        type=int,
                        default=const.MAX_QUERY_LENGTH)
    parser.add_argument("--disable-tf-function", type=eval, default=False)
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--prefix", type=str, default='pretrained')

    global args
    args = parser.parse_args()

    # Index, results and reader data are all grouped per model; the resolved
    # embeddings directory is reported in the configuration but
    # args.embeddings_path itself is left untouched.
    model_type = os.path.basename(args.checkpoint_path)
    embeddings_path = os.path.join(args.embeddings_path, "shards-42031",
                                   model_type)
    args_dict = {**args.__dict__, "embeddings_path": embeddings_path}

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

    # Use only the script's base name without its extension: the raw __file__
    # may contain directories, and str.rstrip(".py") strips trailing '.', 'p'
    # and 'y' characters rather than the suffix. The timestamp avoids ':' so
    # the directory name is valid on every filesystem.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    config_path = "configs/{}/{}/config.yml".format(
        script_name,
        datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # any TPU set-up failure: fall back to GPUs, then to the default strategy
        devices = tf.config.list_physical_devices("GPU")
        if devices:
            # NOTE(review): the device list is hard-coded to two GPUs —
            # confirm behaviour on single-GPU hosts.
            strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
        else:
            strategy = tf.distribute.get_strategy()

    index_path = os.path.join(args.index_path, model_type)
    os.makedirs(index_path, exist_ok=True)
    indexer = create_or_retrieve_indexer(index_path=index_path,
                                         embeddings_path=embeddings_path)
    # exit(0) # only create index
    question_encoder = load_checkpoint(checkpoint_path=args.checkpoint_path,
                                       strategy=strategy)
    questions, answers = load_qas_test_data()
    tokenizer = get_tokenizer(model_name=args.pretrained_model,
                              prefix=args.prefix)
    dataset = prepare_dataset(args.qas_tfrecord_path,
                              strategy=strategy,
                              tokenizer=tokenizer,
                              max_query_length=args.max_query_length)
    question_embeddings = generate_embeddings(
        question_encoder=question_encoder, dataset=dataset, strategy=strategy)
    top_ids_and_scores = search_knn(indexer=indexer,
                                    question_embeddings=question_embeddings)
    all_docs = load_ctx_sources()

    print("Validating... ")
    top_k_hits_path = os.path.join(args.result_path, model_type,
                                   "top_k_hits.txt")
    os.makedirs(os.path.dirname(top_k_hits_path), exist_ok=True)

    start_time = time.perf_counter()
    questions_doc_hits = validate(top_ids_and_scores=top_ids_and_scores,
                                  answers=answers,
                                  ctx_sources=all_docs,
                                  top_k_hits_path=top_k_hits_path)
    print("done in {}s !".format(time.perf_counter() - start_time))
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )

    print("Generating reader data... ")
    reader_data_path = os.path.join(args.reader_data_path, model_type,
                                    "reader_data.json")
    os.makedirs(os.path.dirname(reader_data_path), exist_ok=True)
    start_time = time.perf_counter()
    save_results(questions=questions,
                 answers=answers,
                 all_docs=all_docs,
                 top_passages_and_scores=top_ids_and_scores,
                 per_question_hits=questions_doc_hits,
                 out_file=reader_data_path)

    print("done in {}s !".format(time.perf_counter() - start_time))
    print(
        "----------------------------------------------------------------------------------------------------------------------"
    )
Ejemplo n.º 6
0
def main():
    """Train the bi-encoder retriever with a custom distributed training loop.

    Parses the command-line options into the module-level ``args``, prints
    and persists the effective configuration, selects a distribution
    strategy (TPU first, then GPU, then the default strategy), builds the
    training dataset, creates the question/context encoders, optimizer and
    loss function under the strategy scope (published as module-level
    globals), restores the latest checkpoint if one exists, and then trains
    for the remaining epochs, saving a checkpoint after each one.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--train-data-size", type=int, default=const.TRAIN_DATA_SIZE)
    parser.add_argument("--data-path", type=str, default=const.DATA_PATH, help="Path to the `.tfrecord` data. Data in this file is already preprocessed into tensor format")
    parser.add_argument("--max-context-length", type=int, default=const.MAX_CONTEXT_LENGTH, help="Maximum length of a document")
    parser.add_argument("--max-query-length", type=int, default=const.MAX_QUERY_LENGTH, help="Maximum length of a question")
    parser.add_argument("--batch-size", type=int, default=const.BATCH_SIZE, help="Batch size on each compute device")
    parser.add_argument("--epochs", type=int, default=const.EPOCHS)
    parser.add_argument("--learning-rate", type=float, default=const.LEARNING_RATE)
    parser.add_argument("--warmup-steps", type=int, default=const.WARMUP_STEPS)
    parser.add_argument("--adam-eps", type=float, default=const.ADAM_EPS)
    # NOTE(review): type=eval evaluates the raw command-line text as Python;
    # only pass trusted values (e.g. "(0.9, 0.999)", True / False).
    parser.add_argument("--adam-betas", type=eval, default=const.ADAM_BETAS)
    parser.add_argument("--weight-decay", type=float, default=const.WEIGHT_DECAY)
    parser.add_argument("--max-grad-norm", type=float, default=const.MAX_GRAD_NORM)
    parser.add_argument("--shuffle", type=eval, default=const.SHUFFLE)
    parser.add_argument("--seed", type=int, default=const.SHUFFLE_SEED)
    parser.add_argument("--checkpoint-path", type=str, default=const.CHECKPOINT_PATH)
    parser.add_argument("--ctx-encoder-trainable", type=eval, default=const.CTX_ENCODER_TRAINABLE, help="Whether the context encoder's weights are trainable")
    parser.add_argument("--question-encoder-trainable", type=eval, default=const.QUESTION_ENCODER_TRAINABLE, help="Whether the question encoder's weights are trainable")
    parser.add_argument("--tpu", type=str, default=const.TPU_NAME)
    parser.add_argument("--loss-fn", type=str, choices=['inbatch', 'threelevel', 'twolevel', 'hardnegvsneg', 'hardnegvsnegsoftmax', 'threelevelsoftmax'], default='threelevel')
    parser.add_argument("--use-pooler", type=eval, default=True)
    parser.add_argument("--load-optimizer", type=eval, default=True)
    parser.add_argument("--tokenizer", type=str, default="bert-base-uncased")
    parser.add_argument("--question-pretrained-model", type=str, default='bert-base-uncased')
    parser.add_argument("--context-pretrained-model", type=str, default='bigbird')
    parser.add_argument("--within-size", type=int, default=8)
    parser.add_argument("--prefix", type=str, default=None)

    global args
    args = parser.parse_args()
    args_dict = args.__dict__

    configs = ["{}: {}".format(k, v) for k, v in args_dict.items()]
    configs_string = "\t" + "\n\t".join(configs) + "\n"
    print("************************* Configurations *************************")
    print(configs_string)
    print("----------------------------------------------------------------------------------------------------------------------")

    # splitext removes only the extension; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters from the script name.
    script_name = os.path.splitext(os.path.basename(__file__))[0]
    config_path = "configs/{}/{}/config.yml".format(script_name, datetime.now().strftime("%Y-%m-%d %H-%M-%S"))
    write_config(config_path, args_dict)

    epochs = args.epochs

    global strategy
    try:  # detect TPUs
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args.tpu)  # TPU detection
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.TPUStrategy(resolver)
    except Exception:  # any TPU set-up failure: fall back to GPUs, then to the default strategy
        devices = tf.config.list_physical_devices("GPU")
        if devices:
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()

    tf.random.set_seed(args.seed)

    tokenizer = get_tokenizer(model_name=args.tokenizer, prefix=args.prefix)

    # ---- Data pipeline ----
    # 1. Load retriever data (in `.tfrecord` format, stored serialized `tf.int32` tensor)
    # 2. Pad sequences to the same length
    # 3. Shuffle: shuffle before batching so each sample can be batched with
    #    different samples in different epochs
    # 4. Repeat to produce an indefinite data stream
    # 5. Batch the dataset
    # 6. Prefetch the dataset (to speed up training)
    dataset = biencoder_manipulator.load_retriever_tfrecord_int_data(
        input_path=args.data_path,
        shuffle=args.shuffle,
        shuffle_seed=args.seed
    )
    dataset = biencoder_manipulator.pad(
        dataset,
        sep_token_id=tokenizer.sep_token_id,
        max_context_length=args.max_context_length,
        max_query_length=args.max_query_length
    )
    if args.loss_fn == 'inbatch':
        # The in-batch loss keeps only the first two contexts per question.
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:2]
            },
            num_parallel_calls=tf.data.AUTOTUNE,
        )
    else:
        # All other losses keep the first `within_size` contexts per question.
        dataset = dataset.map(
            lambda x: {
                'question': x['question'],
                'contexts': x['contexts'][:args.within_size]
            }
        )
    dataset = dataset.shuffle(buffer_size=60000)
    dataset = dataset.repeat()
    dataset = dataset.batch(args.batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    # ---- Distribute the dataset ----
    dist_dataset = strategy.distribute_datasets_from_function(
        lambda _: dataset
    )
    iterator = iter(dist_dataset)

    # ---- Set up for distributed training ----
    steps_per_epoch = args.train_data_size // (args.batch_size * strategy.num_replicas_in_sync)
    # Published as module-level globals; presumably read by the train step
    # built by get_dist_train_step — verify against that helper.
    global optimizer
    global loss_fn
    global retriever
    with strategy.scope():
        # Instantiate question encoder
        question_encoder = get_encoder(
            model_name=args.question_pretrained_model,
            args=args,
            trainable=args.question_encoder_trainable,
            prefix=args.prefix
        )

        # Instantiate context encoder
        context_encoder = get_encoder(
            model_name=args.context_pretrained_model,
            args=args,
            trainable=args.ctx_encoder_trainable,
            prefix=args.prefix
        )

        retriever = models.BiEncoder(
            question_model=question_encoder,
            ctx_model=context_encoder,
            use_pooler=args.use_pooler
        )

        # Instantiate the optimizer
        optimizer = optimizers.get_adamw(
            steps_per_epoch=steps_per_epoch,
            warmup_steps=args.warmup_steps,
            epochs=args.epochs,
            learning_rate=args.learning_rate,
            eps=args.adam_eps,
            beta_1=args.adam_betas[0],
            beta_2=args.adam_betas[1],
            weight_decay=args.weight_decay,
        )

        # Define loss function ('inbatch' is the fall-through case)
        if args.loss_fn == 'threelevel':
            loss_fn = biencoder.ThreeLevelDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'twolevel':
            loss_fn = biencoder.TwoLevelDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == "hardnegvsneg":
            loss_fn = biencoder.HardNegVsNegDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'hardnegvsnegsoftmax':
            loss_fn = biencoder.HardNegVsNegSoftMaxDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        elif args.loss_fn == 'threelevelsoftmax':
            loss_fn = biencoder.ThreeLevelSoftMaxDPRLoss(batch_size=args.batch_size, within_size=args.within_size)
        else:
            loss_fn = biencoder.InBatchDPRLoss(batch_size=args.batch_size)

    # ---- Distributed train step ----
    dist_train_step = get_dist_train_step(
        model_name=args.context_pretrained_model
    )

    # ---- Configure checkpoint ----
    with strategy.scope():
        checkpoint_path = args.checkpoint_path
        ckpt = tf.train.Checkpoint(
            model=retriever,
            current_epoch=tf.Variable(0)
        )
        if not args.load_optimizer:
            # Restore into a throw-away copy so the real optimizer keeps its
            # fresh state; the real one is re-attached after the restore.
            tmp_optimizer = copy.deepcopy(optimizer)
            ckpt.optimizer = tmp_optimizer
        else:
            ckpt.optimizer = optimizer

        ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

        # if a checkpoint exists, restore the latest checkpoint
        if ckpt_manager.latest_checkpoint:
            ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
            current_epoch = ckpt.current_epoch.numpy()
            print("Latest checkpoint restored -- Model trained for {} epochs".format(current_epoch))
        else:
            print("Checkpoint not found. Train from scratch")
            current_epoch = 0

        if not args.load_optimizer:
            ckpt.optimizer = optimizer

    # ---- Bootstrap ----
    # Run one step on a sample batch before the timed training loop.
    sample = next(iter(dist_dataset))
    dist_train_step(sample)

    # ---- Training loop ----
    for epoch in range(current_epoch, epochs):
        print("*************** Epoch {:02d}/{:02d} ***************".format(epoch + 1, epochs))
        begin_epoch_time = time.perf_counter()

        for step in range(steps_per_epoch):
            begin_step_time = time.perf_counter()
            loss = dist_train_step(next(iterator))
            print("Step {: <6d}Loss: {: <20f}Elapsed: {}".format(
                step + 1,
                loss.numpy(),
                time.perf_counter() - begin_step_time,
            ))

        print("\nEpoch's elapsed time: {}\n".format(time.perf_counter() - begin_epoch_time))

        ckpt.current_epoch.assign_add(1)

        # Checkpoint the model
        ckpt_save_path = ckpt_manager.save()
        print('\nSaving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))