def transpose_checkpoint_embedding(checkpoint):
    """Rebuild the BERT logits graph from an ONNX checkpoint and save it.

    The initializers stored in ``checkpoint`` seed a newly constructed
    ``bert.Bert`` model (configured from the module-level ``config`` and
    ``args``). Once the logits graph has been added, the builder's model
    proto is written to ``pargs.output_dir`` under the checkpoint's own
    file name.

    NOTE(review): no tensor is transposed explicitly in this function;
    presumably the embedding layout is changed while the model is built
    from the loaded initializers -- confirm against ``bert.Bert``.

    Args:
        checkpoint: Path of the ONNX checkpoint file to convert.
    """
    print(f"Loading checkpoint: {checkpoint}")
    loaded_weights = utils.load_initializers_from_onnx(checkpoint)
    model = bert.Bert(
        config,
        initializers=loaded_weights,
        pipeline=args.pipeline,
    )

    # Exactly five values are expected back; the last one is not needed to
    # build the logits graph.
    graph_inputs = bert.bert_add_inputs(args, model)
    indices, positions, segments, masks, _ = graph_inputs
    bert.bert_logits_graph(
        model, indices, positions, segments, masks, args.pipeline)

    # Round-trip the serialized proto through onnx so it can be saved.
    serialized_model = model.builder.getModelProto()
    onnx_model = onnx.load_from_string(serialized_model)

    destination = os.path.join(pargs.output_dir,
                               os.path.basename(checkpoint))
    print(f"Saving to: {destination}")
    onnx.save(onnx_model, destination)
def training_run(bert_args, config, initializers, checkpoint_paths):
    """Fine-tune every checkpoint in turn on one shared training session.

    The model graph, dataset and training session are built once. For each
    entry in ``checkpoint_paths`` the session's host weights are reset from
    that file, copied to the device, and ``run_fine_tuning_store_ckpt`` is
    invoked with the checkpoint's base name (extension stripped).

    The acquired device is always detached before returning, including when
    session creation or fine-tuning raises.

    Args:
        bert_args: Run options. This function reads ``execution_mode`` and
            ``aggregate_metrics_over_steps`` directly and forwards the whole
            object to the helper functions.
        config: Model configuration passed to ``Bert``.
        initializers: Initial weights passed to ``Bert``.
        checkpoint_paths: Iterable of checkpoint file paths usable with
            ``os.path``.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()
    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)
    device = acquire_device(bert_args, request_ipus)

    # Everything after acquisition runs under try/finally so a failure while
    # building the session or fine-tuning does not leave the device attached.
    try:
        logger.info(f"Dataset length: {len(dataset)}")

        writer = bert_writer(bert_args)
        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=bert_args.aggregate_metrics_over_steps)
        optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration,
                                                      "SGD", model.tensors)
        session, anchors = bert_training_session(model, bert_args, data_flow,
                                                 losses, device,
                                                 optimizer_factory)

        for path in checkpoint_paths:
            ckpt_name = os.path.splitext(os.path.basename(path))[0]
            # Swap in this checkpoint's weights on the host, then push them
            # to the device before fine-tuning starts.
            session.resetHostWeights(os.path.abspath(path))
            session.weightsFromHost()

            logger.info(f"Fine-tuning started for checkpoint: {path}")

            run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                       dataset, predictions, losses, labels,
                                       anchors)
    finally:
        device.detach()
Example #3
0
def transpose_checkpoint_embedding(checkpoint):
    """Rebuild the BERT logits graph from an ONNX checkpoint and save it.

    A ``bert.Bert`` model is constructed on a fresh ``popart.Builder`` using
    the initializers read from ``checkpoint`` together with the module-level
    ``config`` and ``args``. After the logits graph is added, the builder's
    model proto is saved into ``pargs.output_dir`` with the same file name as
    the input checkpoint.

    NOTE(review): this function contains no explicit transpose; presumably
    the embedding is re-laid-out during model construction -- confirm
    against ``bert.Bert``.

    Args:
        checkpoint: Path of the ONNX checkpoint file to convert.
    """
    print(f"Loading checkpoint: {checkpoint}")
    loaded_weights = utils.load_initializers_from_onnx(checkpoint)

    opset_versions = {
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    }
    model = bert.Bert(
        config,
        builder=popart.Builder(opsets=opset_versions),
        initializers=loaded_weights,
        execution_mode=args.execution_mode,
    )

    # Exactly five values are expected back; the last one is not needed to
    # build the logits graph.
    graph_inputs = bert.bert_add_inputs(args, model)
    indices, positions, segments, masks, _ = graph_inputs
    bert.bert_logits_graph(model, indices, positions, segments, masks)

    # Round-trip the serialized proto through onnx so it can be saved.
    serialized_model = model.builder.getModelProto()
    onnx_model = onnx.load_from_string(serialized_model)

    destination = os.path.join(pargs.output_dir,
                               os.path.basename(checkpoint))
    print(f"Saving to: {destination}")
    onnx.save(onnx_model, destination)
Example #4
0
def pooled_validation_run(bert_args,
                          config,
                          initializers,
                          checkpoint_paths,
                          num_processes=1,
                          available_ipus=16):
    """Run inference for several checkpoints on one shared inference session.

    The model graph and inference session are built once. For each entry in
    ``checkpoint_paths`` the session's weights are reset from that file and
    ``run_inference_extract_result`` is called; each result is stored via
    ``result_into_recursive_path`` relative to ``bert_args.checkpoint_dir``.

    The acquired device is always detached before returning, including when
    session creation or inference raises.

    Side effect: ``bert_args.squad_results_dir`` is overwritten with a
    temporary directory that is removed when this function exits.

    Args:
        bert_args: Run options. This function reads ``execution_mode``,
            ``aggregate_metrics_over_steps`` and ``checkpoint_dir``, sets
            ``squad_results_dir``, and forwards the object to the helpers.
        config: Model configuration passed to ``Bert``.
        initializers: Initial weights passed to ``Bert``.
        checkpoint_paths: Iterable of path objects exposing ``.absolute()``
            (e.g. ``pathlib.Path``).
        num_processes: Number of processes expected to run concurrently;
            only used for the IPU budget check.
        available_ipus: Total number of IPUs the processes may use together.

    Returns:
        The ``recursive_defaultdict`` populated with one result per
        checkpoint.

    Raises:
        ValueError: If ``request_ipus * num_processes`` exceeds
            ``available_ipus``.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(
                     opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}),
                 initializers=initializers)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks, bert_args.execution_mode)
    inputs = [indices, positions, segments, *masks]
    outputs = bert_add_logit_outputs(model, logits)

    with tempfile.TemporaryDirectory() as temp_results_path:
        # Inject the checkpoint-specific squad results directory into the dataset args otherwise
        # they overwrite each other when multithreaded
        bert_args.squad_results_dir = temp_results_path

        dataset = get_bert_dataset(
            model, bert_args, [indices, positions, segments, masks, labels])
        logger.info(f"Dataset length: {len(dataset)}")

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=None,
            recording_steps=bert_args.aggregate_metrics_over_steps)

        request_ipus, _ = calc_required_ipus(bert_args, model)

        # Fail before acquiring anything if the combined demand of all
        # processes cannot fit in the available IPUs.
        if request_ipus * num_processes > available_ipus:
            raise ValueError(
                "Cannot run with requested number of processes - too many IPUs required")

        device = acquire_device(bert_args, request_ipus)

        # Release the device even if session creation or inference fails, so
        # it does not stay attached after an error.
        try:
            session, anchors = bert_inference_session(
                model, bert_args, data_flow, device)

            model_results = recursive_defaultdict()
            for path in checkpoint_paths:
                session.resetHostWeights(str(path.absolute()))
                session.weightsFromHost()

                logger.info(f"Inference started for checkpoint: {path.absolute()}")
                result = run_inference_extract_result(bert_args,
                                                      session,
                                                      dataset,
                                                      inputs,
                                                      logits,
                                                      anchors,
                                                      iteration)

                result_into_recursive_path(model_results, path, bert_args.checkpoint_dir, result)
        finally:
            device.detach()
    return model_results