def training_run(bert_args, config, initializers, checkpoint_paths):
    """Fine-tune BERT once per checkpoint, reusing a single training session.

    Builds the BERT graph (logits, inference outputs and loss), the dataset,
    and one training session ("SGD" scheduled optimizer). Then, for every
    entry in ``checkpoint_paths``, the session's host weights are reset from
    that checkpoint, pushed to the device, and ``run_fine_tuning_store_ckpt``
    is invoked.

    Args:
        bert_args: Parsed run arguments. This function reads
            ``execution_mode`` and ``aggregate_metrics_over_steps`` directly,
            and passes the object on to the helpers.
        config: BERT model configuration passed to ``Bert``.
        initializers: Initial weight values passed to ``Bert``.
        checkpoint_paths: Iterable of checkpoint file paths (``str`` or
            path-like; they are passed through ``os.path`` helpers).

    The acquired device is always detached before returning, including when
    session construction or fine-tuning raises.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()
    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)
    device = acquire_device(bert_args, request_ipus)

    # Everything after acquisition runs under try/finally so the device is
    # released even if session construction or fine-tuning raises.
    try:
        logger.info(f"Dataset length: {len(dataset)}")

        writer = bert_writer(bert_args)
        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=bert_args.aggregate_metrics_over_steps)
        optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration,
                                                      "SGD", model.tensors)
        session, anchors = bert_training_session(model, bert_args, data_flow,
                                                 losses, device,
                                                 optimizer_factory)

        for path in checkpoint_paths:
            # File name without directory or extension, used to label the
            # fine-tuned checkpoint (e.g. "/ckpts/model_1.onnx" -> "model_1").
            ckpt_name = os.path.splitext(os.path.basename(path))[0]
            # Reload the session's host-side weights from this checkpoint,
            # then copy them to the device before training on it.
            session.resetHostWeights(os.path.abspath(path))
            session.weightsFromHost()

            logger.info(f"Fine-tuning started for checkpoint: {path}")

            run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                       dataset, predictions, losses, labels,
                                       anchors)
    finally:
        device.detach()
# Ejemplo n.º 2
# 0
def run_embedding_layer(args):
    """Build a BERT graph containing only the embedding layer and, in
    inference mode, run it on the first batch of the dataset.

    Args:
        args: Parsed run arguments. This function reads ``seed``,
            ``execution_mode``, ``inference``, ``task``, ``synthetic_data``,
            ``low_latency_inference`` and ``aggregate_metrics_over_steps``.

    Returns:
        ``None`` when ``args.inference`` is false (the graph is built but
        nothing is executed). Otherwise, a list holding the result of
        ``bert_process_infer_data`` for the processed batch. The list is only
        filled when ``args.task == "SQUAD"`` and synthetic data is disabled;
        in all other cases it is empty.

    When inference runs, realtime scheduling is always disabled again and the
    device is always detached, even if inference raises.
    """
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    # TODO: Change slice to opset10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions will have the matrices instead of the index vector.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    # Only the embedding output is used as the graph's "logits".
    logits = tuple([model.embedding(indices, positions, segments)])

    # Only the inference path is implemented; nothing is executed otherwise.
    if not args.inference:
        return None

    outputs = bert_add_infer_outputs(model, logits)
    losses = []
    embedding_dict, positional_dict = model.get_model_embeddings()

    dataset = get_bert_dataset(
        model, args, [indices, positions, segments, masks, labels],
        embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    iteration = Iteration(
        args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=None,
        recording_steps=args.aggregate_metrics_over_steps)

    request_ipus, _ = calc_required_ipus(args, model)

    device = acquire_device(args, request_ipus)
    try:
        session, anchors = bert_inference_session(model, args, data_flow,
                                                  losses, device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        # NOTE: the generic bert_infer_loop is deliberately bypassed here in
        # favour of the hand-rolled single-batch loop below.
        save_results = args.task == "SQUAD" and not args.synthetic_data

        start_times = defaultdict(list)
        end_times = defaultdict(list)
        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times,
                                            end_times,
                                            dataset.batches_per_step)
        else:
            stepio = None

        enable_realtime_scheduling(args)
        try:
            output = []
            logger.info(dataset)
            for data in dataset:
                # static_data is updated in place because the callback stepio
                # (when used) was created with a reference to this dict.
                static_data.update({t: data[t] for t in inputs})
                result = bert_process_infer_data(args, session, static_data,
                                                 anchors, logits, iteration,
                                                 start_times, end_times,
                                                 stepio)
                if save_results:
                    output.append(result)
                # Only the first batch is processed.
                break
        finally:
            disable_realtime_scheduling(args)
    finally:
        device.detach()
    return output
Ejemplo n.º 3
0
def pooled_validation_run(bert_args,
                          config,
                          initializers,
                          checkpoint_paths,
                          num_processes=1,
                          available_ipus=16):
    """Run inference for each checkpoint with one shared session and collect
    the per-checkpoint results.

    Args:
        bert_args: Parsed run arguments. Note that
            ``bert_args.squad_results_dir`` is overwritten with a temporary
            directory, and that directory is deleted before this function
            returns.
        config: BERT model configuration passed to ``Bert``.
        initializers: Initial weight values passed to ``Bert``.
        checkpoint_paths: Iterable of path objects exposing ``.absolute()``
            (e.g. ``pathlib.Path``).
        num_processes: Number of processes expected to run this concurrently.
            Used only for the IPU budget check.
        available_ipus: Total IPUs available to all processes.

    Returns:
        The ``recursive_defaultdict`` filled by ``result_into_recursive_path``
        with one result per checkpoint.

    Raises:
        ValueError: If ``request_ipus * num_processes`` exceeds
            ``available_ipus``.

    The acquired device is always detached, even if inference raises.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(
                     opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}),
                 initializers=initializers)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks, bert_args.execution_mode)
    inputs = [indices, positions, segments, *masks]
    outputs = bert_add_logit_outputs(model, logits)

    with tempfile.TemporaryDirectory() as temp_results_path:
        # Inject the checkpoint-specific squad results directory into the dataset args otherwise
        # they overwrite each other when multithreaded
        bert_args.squad_results_dir = temp_results_path

        dataset = get_bert_dataset(
            model, bert_args, [indices, positions, segments, masks, labels])
        logger.info(f"Dataset length: {len(dataset)}")

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=None,
            recording_steps=bert_args.aggregate_metrics_over_steps)

        request_ipus, _ = calc_required_ipus(bert_args, model)

        # Every concurrent process needs its own set of IPUs.
        if request_ipus * num_processes > available_ipus:
            raise ValueError(
                "Cannot run with requested number of processes - too many IPUs required")

        device = acquire_device(bert_args, request_ipus)
        # Release the device even if session creation or inference fails.
        try:
            session, anchors = bert_inference_session(
                model, bert_args, data_flow, device)

            model_results = recursive_defaultdict()
            for path in checkpoint_paths:
                # Reload host-side weights from this checkpoint and copy them
                # to the device before running inference.
                session.resetHostWeights(str(path.absolute()))
                session.weightsFromHost()

                logger.info(f"Inference started for checkpoint: {path.absolute()}")
                result = run_inference_extract_result(bert_args,
                                                      session,
                                                      dataset,
                                                      inputs,
                                                      logits,
                                                      anchors,
                                                      iteration)

                result_into_recursive_path(model_results, path, bert_args.checkpoint_dir, result)
        finally:
            device.detach()
    return model_results