Example #1
def test_linear_schedule_with_continue_from_epoch(start_epoch, warmup, steps_per_decay, learning_rate):
    """Make sure the optimiser restarts the schedule from the correct point when resuming training
    from a given epoch"""

    config = TestConfig(
        continue_training_from_epoch=start_epoch,
        epochs=10,
        learning_rate_function="Linear",
        lr_warmup_steps=warmup,
        lr_steps_per_decay_update=steps_per_decay,
        batches_per_step=10,
        learning_rate=learning_rate)

    iteration = Iteration(
        config,
        steps_per_epoch=20,
        writer=None,
        recording_steps=1)

    def generate_step_truth():
        if warmup > 0:
            schedule = linear_schedule(0, warmup, 1, 1e-7, learning_rate)
        else:
            schedule = {}
        schedule.update(linear_schedule(warmup, iteration.total_steps, steps_per_decay, learning_rate, 1e-7))
        return schedule
    schedule = generate_step_truth()
    step_to_match = list(filter(lambda k: k <= iteration.count, schedule.keys()))[-1]
    expected = schedule[step_to_match]

    factory = LinearOptimizerFactory(config, iteration)
    lr = factory.option_values["defaultLearningRate"]
    assert lr == expected
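The test assumes a linear_schedule(start, end, interval, low, high) helper that returns a dict mapping global step numbers to learning rates, moving linearly from low to high with one entry every interval steps. A minimal sketch consistent with that usage (the real helper lives alongside the optimizer factories in the repo):

import numpy as np

def linear_schedule(start, end, interval, low, high):
    # One schedule entry every `interval` steps, linearly spaced
    # between `low` and `high` over the half-open range [start, end).
    update_steps = np.arange(start, end, interval).astype(np.uint32)
    updates = np.linspace(low, high, len(update_steps))
    return dict(zip(update_steps, updates))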
Example #2
def training_run(bert_args, config, initializers, checkpoint_paths):
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()
    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)
    device = acquire_device(bert_args, request_ipus)

    logger.info(f"Dataset length: {len(dataset)}")

    writer = bert_writer(bert_args)
    iteration = Iteration(
        bert_args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=writer,
        recording_steps=bert_args.aggregate_metrics_over_steps)
    optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration, "SGD",
                                                  model.tensors)
    session, anchors = bert_training_session(model, bert_args, data_flow,
                                             losses, device, optimizer_factory)

    for path in checkpoint_paths:
        ckpt_name = os.path.splitext(os.path.basename(path))[0]
        session.resetHostWeights(os.path.abspath(path))
        session.weightsFromHost()

        logger.info(f"Fine-tuning started for checkpoint: {path}")

        run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                   dataset, predictions, losses, labels,
                                   anchors)

    device.detach()
Example #3
def test_schedule_with_continue_from_epoch(start_epoch, steps_per_epoch, num_epochs, lr_schedule, expected):

    config = TestConfig(**{
        "continue_training_from_epoch": start_epoch,
        "epochs": num_epochs,
        "lr_schedule_by_step": lr_schedule
    })

    iteration = Iteration(
        config,
        batches_per_step=10,
        steps_per_epoch=steps_per_epoch,
        writer=None,
        recording_steps=1)

    factory = ScheduledOptimizerFactory(config, iteration)
    lr = factory.option_values["defaultLearningRate"]
    assert lr == expected
Example #4
def test_schedule_with_continue_from_epoch(start_epoch, steps_per_epoch,
                                           num_epochs, lr_schedule, expected):
    """Make sure the optimiser restarts the schedule from the correct point when resuming training
    from a given epoch"""

    config = TestConfig(
        **{
            "continue_training_from_epoch": start_epoch,
            "epochs": num_epochs,
            "lr_schedule_by_step": lr_schedule,
            "batches_per_step": 10
        })

    iteration = Iteration(config,
                          steps_per_epoch=steps_per_epoch,
                          writer=None,
                          recording_steps=1)

    factory = ScheduledOptimizerFactory(config, iteration)
    lr = factory.option_values["defaultLearningRate"]
    assert lr == expected
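For reference, a hypothetical parametrization for this test (the concrete values below are invented for illustration), assuming Iteration resumes at global step start_epoch * steps_per_epoch and the factory selects the schedule entry with the largest key at or before that step. The decorator would sit directly above the test function:

import pytest

@pytest.mark.parametrize(
    "start_epoch, steps_per_epoch, num_epochs, lr_schedule, expected",
    [
        (0, 30, 5, {0: 1e-4, 50: 1e-5, 100: 1e-6}, 1e-4),  # step 0   -> entry 0
        (2, 30, 5, {0: 1e-4, 50: 1e-5, 100: 1e-6}, 1e-5),  # step 60  -> entry 50
        (4, 30, 5, {0: 1e-4, 50: 1e-5, 100: 1e-6}, 1e-6),  # step 120 -> entry 100
    ])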
Example #5
def run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session, dataset,
                               predictions, losses, labels, anchors):

    writer = bert_writer(bert_args)
    iteration = Iteration(
        bert_args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=writer,
        recording_steps=bert_args.aggregate_metrics_over_steps)
    optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration,
                                                  model.tensors)

    for iteration.epoch in range(iteration.start_epoch, bert_args.epochs):
        for data in dataset:
            bert_process_data(bert_args, session, labels, data, anchors,
                              losses, predictions, iteration,
                              optimizer_factory)

    model_fn = os.path.join(bert_args.checkpoint_dir, "squad_output",
                            f"squad_final_{ckpt_name}.onnx")
    session.modelToHost(model_fn)
Example #6
def run_embedding_layer(args):
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    # TODO: Change slice to opset10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions carry the
    # embedding matrices rather than the index vectors.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    logits = tuple([model.embedding(indices, positions, segments)])

    if args.inference:
        outputs = bert_add_logit_outputs(model, logits)
        writer = None

        dataset = get_bert_dataset(
            model, args, [indices, positions, segments, masks, labels])

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            args,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=args.aggregate_metrics_over_steps)

        request_ipus = bert_required_ipus(args, model)

        device = acquire_device(args, request_ipus)

        session, anchors = bert_inference_session(model, args, data_flow,
                                                  device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        """bert_infer_loop(args, session,
                        dataset, inputs, logits, anchors,
                        iteration)"""
        save_results = args.task == "SQUAD" and not (args.synthetic_data
                                                     or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)
        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times,
                                            end_times,
                                            dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)
        for data in dataset:
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data,
                                             anchors, logits, iteration,
                                             start_times, end_times, stepio)
            if save_results:
                output.append(result)
            break

        device.detach()
        return output

    return None
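Note the pattern around static_data above: the callback stepio is created once, closing over the static_data dict, and the loop then mutates that dict in place each batch, so the callbacks always read the current buffers without rebuilding the stepio. A minimal, PopART-free sketch of the idea:

def make_reader(shared):
    # The closure captures the dict object itself, so later in-place
    # updates to it are visible through the returned function.
    return lambda name: shared[name]

shared = {}
read = make_reader(shared)
shared.update({"indices": [1, 2, 3]})
assert read("indices") == [1, 2, 3]
shared.update({"indices": [4, 5, 6]})  # next batch, same dict object
assert read("indices") == [4, 5, 6]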
Example #7
def pooled_validation_run(bert_args,
                          config,
                          initializers,
                          checkpoint_paths,
                          num_processes=1,
                          available_ipus=16):
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(
                     opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}),
                 initializers=initializers)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks, bert_args.execution_mode)
    inputs = [indices, positions, segments, *masks]
    outputs = bert_add_logit_outputs(model, logits)

    with tempfile.TemporaryDirectory() as temp_results_path:
        # Inject a checkpoint-specific SQuAD results directory into the dataset
        # args; otherwise parallel runs overwrite each other's results.
        bert_args.squad_results_dir = temp_results_path

        dataset = get_bert_dataset(
            model, bert_args, [indices, positions, segments, masks, labels])
        logger.info(f"Dataset length: {len(dataset)}")

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=None,
            recording_steps=bert_args.aggregate_metrics_over_steps)

        request_ipus, _ = calc_required_ipus(bert_args, model)

        if request_ipus * num_processes > available_ipus:
            raise ValueError(
                "Cannot run with requested number of processes - too many IPUs required")

        device = acquire_device(bert_args, request_ipus)

        session, anchors = bert_inference_session(
            model, bert_args, data_flow, device)

        model_results = recursive_defaultdict()
        for path in checkpoint_paths:
            session.resetHostWeights(str(path.absolute()))
            session.weightsFromHost()

            logger.info(f"Inference started for checkpoint: {path.absolute()}")
            result = run_inference_extract_result(bert_args,
                                                  session,
                                                  dataset,
                                                  inputs,
                                                  logits,
                                                  anchors,
                                                  iteration)

            result_into_recursive_path(model_results, path, bert_args.checkpoint_dir, result)

        device.detach()
    return model_results
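recursive_defaultdict and result_into_recursive_path are not shown in this example. A common way to build an arbitrarily nested dict of results, consistent with the helper's name (an assumption, not the repo's verbatim code), is:

from collections import defaultdict

def recursive_defaultdict():
    # Every missing key materialises another nested defaultdict, so
    # results[run][checkpoint][metric] = value needs no pre-built levels.
    return defaultdict(recursive_defaultdict)

results = recursive_defaultdict()
results["phase1"]["ckpt_100"]["f1"] = 87.3  # illustrative values only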
Example #8
def test_iteration_stats(task, epochs, seed, exponent_loss, exponent_acc):
    """Assuming aggregate-metrics-over-steps is 1, the data recorded by the iteration and sent to the
    TB logger should exactly match the true data. Here we mock everything out, bar the iteration class
    itself and ensure this is the case."""

    random.seed(seed)

    total_data_length = 5000
    batches_per_step = 500
    recording_steps = 1  # --aggregate-metrics-over-steps

    dataset_length = total_data_length // batches_per_step  # len(dataset)
    if total_data_length % batches_per_step != 0:
        raise ValueError(
            "Dataset not divisible by bps, not supported in this test.")

    args = MockArgs(**{"epochs": epochs, "task": task})
    mock_writer = MockWriter()

    def generate_known_exp_curve(task, epochs, dataset_length, gamma):
        # Force the two losses to differ slightly in the pretraining case, to
        # make sure we account for expected differences between them.
        multipliers = (0.9, 1.05) if task == "PRETRAINING" else (1, )
        return {
            step: [m * math.exp(gamma * step) for m in multipliers]
            for step in range(int(epochs * dataset_length))
        }

    step_losses = generate_known_exp_curve(task, epochs, dataset_length,
                                           exponent_loss)
    step_accuracies = generate_known_exp_curve(task, epochs, dataset_length,
                                               exponent_acc)

    def generate_step_result(step_num):
        duration = random.random()
        hw_cycles = random.randint(1000, 2000)
        return duration, hw_cycles, step_losses[step_num], step_accuracies[
            step_num]

    def mock_stats_fn(loss, accuracy):
        return loss, accuracy

    iteration = Iteration(args, batches_per_step, dataset_length, mock_writer,
                          recording_steps)
    iteration.stats_fn = mock_stats_fn

    epoch_steps = []
    for iteration.epoch in range(args.start_epoch, args.epochs):
        epoch_steps.append(iteration.count)
        for data in range(dataset_length):
            step_result = generate_step_result(iteration.count)
            iteration.add_stats(*step_result)
            iteration.count += 1

    if task == "PRETRAINING":
        loss_keys = ["loss/MLM", "loss/NSP"]
        acc_keys = ["accuracy/MLM", "accuracy/NSP"]
    else:
        loss_keys = ["loss"]
        acc_keys = ["accuracy"]

    def compare_against_true_curve(keys, true_values, writer):
        for key_index, loss_key in enumerate(keys):
            for step_num, loss in writer.scalars[loss_key].items():
                isclose = np.isclose(true_values[step_num][key_index], loss)
                assert isclose

    compare_against_true_curve(loss_keys, step_losses, mock_writer)
    compare_against_true_curve(acc_keys, step_accuracies, mock_writer)
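The MockWriter here only needs to record every scalar exactly as logged, so that compare_against_true_curve can read it back via writer.scalars[key][step]. A minimal sketch, assuming Iteration logs through a TensorBoard-style add_scalar(tag, value, step):

from collections import defaultdict

class MockWriter:
    def __init__(self):
        # scalars["loss/MLM"][step_num] == value exactly as logged
        self.scalars = defaultdict(dict)

    def add_scalar(self, tag, value, step):
        self.scalars[tag][step] = value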