def training_run(bert_args, config, initializers, checkpoint_paths):
    """Build one BERT training session and fine-tune it from each checkpoint.

    The graph (inputs, logits, inference outputs, losses and validation
    outputs), the dataset, the device and the training session are created
    once. For every entry of ``checkpoint_paths`` the session's host weights
    are reset from that file, ``weightsFromHost`` is called, and
    ``run_fine_tuning_store_ckpt`` is run with the shared session.

    Args:
        bert_args: Run options; this function reads ``execution_mode`` and
            ``aggregate_metrics_over_steps`` and forwards the object to the
            helper functions.
        config: Model configuration handed to ``Bert``.
        initializers: Initial weights handed to ``Bert``.
        checkpoint_paths: Iterable of checkpoint file paths. The file name
            without its extension is used as the checkpoint name.

    The acquired device is always detached before returning, including when
    an exception propagates out of session creation or fine-tuning.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()
    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)
    device = acquire_device(bert_args, request_ipus)

    # From here on the device is held: guarantee it is released even if
    # session construction or any fine-tuning run raises.
    try:
        logger.info(f"Dataset length: {len(dataset)}")

        # This writer/iteration/optimizer trio is only used to construct the
        # session; each fine-tuning run builds its own inside
        # run_fine_tuning_store_ckpt.
        writer = bert_writer(bert_args)
        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=bert_args.aggregate_metrics_over_steps)
        optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration,
                                                      "SGD", model.tensors)
        session, anchors = bert_training_session(model, bert_args, data_flow,
                                                 losses, device,
                                                 optimizer_factory)

        for path in checkpoint_paths:
            # Checkpoint name = file name without directory or extension.
            ckpt_name = os.path.splitext(os.path.basename(path))[0]
            session.resetHostWeights(os.path.abspath(path))
            session.weightsFromHost()

            logger.info(f"Fine-tuning started for checkpoint: {path}")

            run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                       dataset, predictions, losses, labels,
                                       anchors)
    finally:
        device.detach()
# Example #2 (0)
def run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session, dataset,
                               predictions, losses, labels, anchors):
    """Run the fine-tuning loop on ``session`` and save the resulting model.

    A fresh writer, ``Iteration`` and ``ScheduledOptimizerFactory`` are
    created for this run. For every epoch from ``iteration.start_epoch`` up
    to (excluding) ``bert_args.epochs`` each element of ``dataset`` is handed
    to ``bert_process_data``. Afterwards ``session.modelToHost`` is called
    with ``<checkpoint_dir>/squad_output/squad_final_<ckpt_name>.onnx``.

    Args:
        bert_args: Run options; this function reads
            ``aggregate_metrics_over_steps``, ``epochs`` and
            ``checkpoint_dir``.
        model: Model object; its ``tensors`` attribute is passed to the
            optimizer factory.
        ckpt_name: Name embedded in the output file name.
        session: Training session used for processing and for saving.
        dataset: Iterable of data items; also provides ``batches_per_step``
            and ``len()``.
        predictions, losses, labels, anchors: Forwarded unchanged to
            ``bert_process_data``.
    """
    writer = bert_writer(bert_args)
    iteration = Iteration(
        bert_args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=writer,
        recording_steps=bert_args.aggregate_metrics_over_steps)
    # NOTE(review): training_run in this file passes an extra "SGD" argument
    # to ScheduledOptimizerFactory; confirm which call matches the current
    # constructor signature.
    optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration,
                                                  model.tensors)

    # The loop target is the attribute itself, so iteration.epoch is updated
    # on every pass and is visible to bert_process_data via `iteration`.
    for iteration.epoch in range(iteration.start_epoch, bert_args.epochs):
        for data in dataset:
            bert_process_data(bert_args, session, labels, data, anchors,
                              losses, predictions, iteration,
                              optimizer_factory)

    output_dir = os.path.join(bert_args.checkpoint_dir, "squad_output")
    # Make sure the target directory exists so a completed fine-tuning run is
    # not lost when the model is written out.
    os.makedirs(output_dir, exist_ok=True)
    model_fn = os.path.join(output_dir, f"squad_final_{ckpt_name}.onnx")
    session.modelToHost(model_fn)