def test_linear_schedule_with_continue_from_epoch(start_epoch, warmup, steps_per_decay, learning_rate):
    """Make sure the optimiser restarts the schedule from the correct point
    when resuming training from a given epoch.

    The expected step -> learning-rate mapping is assembled from
    ``linear_schedule``: a warmup segment when ``warmup > 0``, then a decay
    segment starting at step ``warmup`` and running to
    ``iteration.total_steps``. A ``LinearOptimizerFactory`` built for an
    iteration resumed at ``start_epoch`` must start at the rate of the most
    recent schedule step not after ``iteration.count``.
    """
    config = TestConfig(
        continue_training_from_epoch=start_epoch,
        epochs=10,
        learning_rate_function="Linear",
        lr_warmup_steps=warmup,
        lr_steps_per_decay_update=steps_per_decay,
        batches_per_step=10,
        learning_rate=learning_rate)
    iteration = Iteration(
        config,
        steps_per_epoch=20,
        writer=None,
        recording_steps=1)

    def generate_step_truth():
        """Return the expected {step: learning_rate} schedule."""
        if warmup > 0:
            schedule = linear_schedule(0, warmup, 1, 1e-7, learning_rate)
        else:
            schedule = {}
        schedule.update(linear_schedule(warmup, iteration.total_steps, steps_per_decay, learning_rate, 1e-7))
        return schedule

    schedule = generate_step_truth()
    # The active rate comes from the latest schedule step at or before the
    # current count. max() makes this independent of the dict's insertion
    # order (the previous "last filtered key" relied on keys being ascending).
    step_to_match = max(k for k in schedule if k <= iteration.count)
    expected = schedule[step_to_match]

    factory = LinearOptimizerFactory(config, iteration)
    lr = factory.option_values["defaultLearningRate"]
    assert lr == expected
def training_run(bert_args, config, initializers, checkpoint_paths):
    """Build a BERT training session once and fine-tune it from each checkpoint.

    The model graph, dataset, device and session are created a single time.
    For every path in ``checkpoint_paths`` the session's host weights are reset
    from that file, uploaded with ``weightsFromHost`` and
    ``run_fine_tuning_store_ckpt`` is invoked.

    Args:
        bert_args: Run arguments (reads e.g. ``execution_mode`` and
            ``aggregate_metrics_over_steps``).
        config: Model configuration passed to ``Bert``.
        initializers: Passed to ``Bert`` as ``initializers``.
        checkpoint_paths: Iterable of checkpoint paths to fine-tune from.
    """
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()

    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)

    device = acquire_device(bert_args, request_ipus)
    # Everything after acquisition runs under try/finally so the device is
    # released even if session creation or fine-tuning raises.
    try:
        logger.info(f"Dataset length: {len(dataset)}")

        writer = bert_writer(bert_args)
        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=bert_args.aggregate_metrics_over_steps)
        optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration, "SGD", model.tensors)

        session, anchors = bert_training_session(model, bert_args, data_flow,
                                                 losses, device,
                                                 optimizer_factory)

        for path in checkpoint_paths:
            # Checkpoint file name without directory or extension, used to
            # name the stored fine-tuned model.
            ckpt_name = os.path.splitext(os.path.basename(path))[0]
            session.resetHostWeights(os.path.abspath(path))
            session.weightsFromHost()

            logger.info(f"Fine-tuning started for checkpoint: {path}")

            run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                       dataset, predictions, losses, labels,
                                       anchors)
    finally:
        device.detach()
def test_schedule_with_continue_from_epoch(start_epoch, steps_per_epoch, num_epochs, lr_schedule, expected):
    """Check the initial learning rate of a ScheduledOptimizerFactory when
    training resumes from ``start_epoch`` with a step-based schedule.

    NOTE(review): another function with this exact name is defined later in
    this file, so this definition is shadowed at import time and never runs.
    """
    settings = {
        "continue_training_from_epoch": start_epoch,
        "epochs": num_epochs,
        "lr_schedule_by_step": lr_schedule,
    }
    config = TestConfig(**settings)
    iteration = Iteration(config,
                          batches_per_step=10,
                          steps_per_epoch=steps_per_epoch,
                          writer=None,
                          recording_steps=1)
    initial_lr = ScheduledOptimizerFactory(config, iteration).option_values["defaultLearningRate"]
    assert initial_lr == expected
def test_schedule_with_continue_from_epoch(start_epoch, steps_per_epoch, num_epochs, lr_schedule, expected):
    """Make sure the optimiser restarts the schedule from the correct point
    when resuming training from a given epoch.

    ``batches_per_step`` is supplied through the config (fixed at 10) rather
    than to ``Iteration`` directly.
    """
    settings = {
        "continue_training_from_epoch": start_epoch,
        "epochs": num_epochs,
        "lr_schedule_by_step": lr_schedule,
        "batches_per_step": 10,
    }
    config = TestConfig(**settings)
    iteration = Iteration(config,
                          steps_per_epoch=steps_per_epoch,
                          writer=None,
                          recording_steps=1)
    initial_lr = ScheduledOptimizerFactory(config, iteration).option_values["defaultLearningRate"]
    assert initial_lr == expected
def run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session, dataset, predictions, losses, labels, anchors):
    """Fine-tune ``session`` over the configured epochs, then save the model.

    A fresh ``Iteration`` and ``ScheduledOptimizerFactory`` are created for
    this checkpoint. Every batch of ``dataset`` is processed for each epoch
    from ``iteration.start_epoch`` up to ``bert_args.epochs``. Afterwards the
    model is written to
    ``<checkpoint_dir>/squad_output/squad_final_<ckpt_name>.onnx``.

    Args:
        bert_args: Run arguments (reads ``epochs``, ``checkpoint_dir`` and
            ``aggregate_metrics_over_steps``).
        model: The Bert model; its ``tensors`` are given to the optimiser.
        ckpt_name: Name used in the output file name.
        session, dataset, predictions, losses, labels, anchors: Forwarded to
            ``bert_process_data`` for each batch.
    """
    writer = bert_writer(bert_args)
    iteration = Iteration(
        bert_args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=writer,
        recording_steps=bert_args.aggregate_metrics_over_steps)
    # training_run builds this factory as (args, iteration, "SGD", tensors).
    # Pass the optimiser type here too; otherwise model.tensors would be
    # bound to the third (optimiser-type) positional parameter.
    optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration, "SGD", model.tensors)

    # Assigning the loop variable to iteration.epoch keeps the Iteration's
    # epoch counter in sync with the loop.
    for iteration.epoch in range(iteration.start_epoch, bert_args.epochs):
        for data in dataset:
            bert_process_data(bert_args, session, labels, data, anchors,
                              losses, predictions, iteration,
                              optimizer_factory)

    output_dir = os.path.join(bert_args.checkpoint_dir, "squad_output")
    # Ensure the target directory exists before writing the model file.
    os.makedirs(output_dir, exist_ok=True)
    model_fn = os.path.join(output_dir, f"squad_final_{ckpt_name}.onnx")
    session.modelToHost(model_fn)
def run_embedding_layer(args):
    """Build a BERT model whose only output is its embedding layer and, in
    inference mode, run it on the first batch of the dataset.

    Returns:
        ``None`` when ``args.inference`` is false. Otherwise a list that holds
        the first batch's result from ``bert_process_infer_data`` when
        ``args.task == "SQUAD"`` and neither synthetic nor generated data is
        used, or an empty list in all other cases.
    """
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    # TODO: Change slice to opset10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)
    # If config.host_embedding is enabled, indices and positions will have the matrices instead of the index vector.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    logits = tuple([model.embedding(indices, positions, segments)])

    if not args.inference:
        return None

    outputs = bert_add_logit_outputs(model, logits)

    dataset = get_bert_dataset(
        model, args, [indices, positions, segments, masks, labels])

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    iteration = Iteration(
        args,
        steps_per_epoch=len(dataset),
        writer=None,
        recording_steps=args.aggregate_metrics_over_steps)

    request_ipus = bert_required_ipus(args, model)

    device = acquire_device(args, request_ipus)
    # Release the device even if session creation or inference raises.
    try:
        session, anchors = bert_inference_session(model, args, data_flow, device)
        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        # Alternative (unused) path:
        # bert_infer_loop(args, session, dataset, inputs, logits, anchors, iteration)
        save_results = args.task == "SQUAD" and not (args.synthetic_data or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)
        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times, end_times, dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)
        for data in dataset:
            # Update in place: static_data was handed to create_callback_stepio
            # above, which presumably reads from this same dict object.
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data, anchors,
                                             logits, iteration, start_times,
                                             end_times, stepio)
            if save_results:
                output.append(result)
            # Only the first batch is processed.
            break
    finally:
        device.detach()
    return output
def pooled_validation_run(bert_args, config, initializers, checkpoint_paths, num_processes=1, available_ipus=16):
    """Run SQuAD-style inference for several checkpoints on one session.

    The inference graph, dataset and session are built once. For each
    checkpoint path the host weights are reset from that file and uploaded,
    and the extracted result is stored in a nested results dict.

    Args:
        bert_args: Run arguments. Note: ``squad_results_dir`` is overwritten
            with a temporary directory that is removed when this returns.
        config: Model configuration passed to ``Bert``.
        initializers: Passed to ``Bert`` as ``initializers``.
        checkpoint_paths: Iterable of path objects exposing ``absolute()``.
        num_processes: Number of processes expected to share the IPUs.
        available_ipus: Total IPUs available to all processes.

    Returns:
        The ``recursive_defaultdict`` filled by ``result_into_recursive_path``.

    Raises:
        ValueError: If ``request_ipus * num_processes`` exceeds
            ``available_ipus``.
    """
    logger.info("Building Model")
    # Build the model in the same execution mode that bert_logits_graph is
    # given below (and that training_run / run_embedding_layer use).
    model = Bert(config,
                 builder=popart.Builder(
                     opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    inputs = [indices, positions, segments, *masks]
    outputs = bert_add_logit_outputs(model, logits)

    with tempfile.TemporaryDirectory() as temp_results_path:
        # Inject the checkpoint-specific squad results directory into the dataset args otherwise
        # they overwrite each other when multithreaded
        bert_args.squad_results_dir = temp_results_path

        dataset = get_bert_dataset(
            model, bert_args, [indices, positions, segments, masks, labels])
        logger.info(f"Dataset length: {len(dataset)}")

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=None,
            recording_steps=bert_args.aggregate_metrics_over_steps)

        request_ipus, _ = calc_required_ipus(bert_args, model)

        if request_ipus * num_processes > available_ipus:
            raise ValueError(
                "Cannot run with requested number of processes - too many IPUs required")

        device = acquire_device(bert_args, request_ipus)
        # Release the device even if session creation or inference raises.
        try:
            session, anchors = bert_inference_session(
                model, bert_args, data_flow, device)

            model_results = recursive_defaultdict()
            for path in checkpoint_paths:
                session.resetHostWeights(str(path.absolute()))
                session.weightsFromHost()

                logger.info(f"Inference started for checkpoint: {path.absolute()}")
                result = run_inference_extract_result(bert_args, session,
                                                      dataset, inputs, logits,
                                                      anchors, iteration)

                result_into_recursive_path(model_results, path,
                                           bert_args.checkpoint_dir, result)
        finally:
            device.detach()
    return model_results
def test_iteration_stats(task, epochs, seed, exponent_loss, exponent_acc):
    """Assuming aggregate-metrics-over-steps is 1, the data recorded by the
    iteration and sent to the TB logger should exactly match the true data.
    Here we mock everything out, bar the iteration class itself and ensure
    this is the case."""
    random.seed(seed)

    total_data_length = 5000
    batches_per_step = 500
    recording_steps = 1  # --aggregate-metrics-over-steps

    if total_data_length % batches_per_step != 0:
        raise ValueError(
            "Dataset not divisible by bps, not supported in this test.")
    dataset_length = total_data_length // batches_per_step  # len(dataset)

    args = MockArgs(**{"epochs": epochs, "task": task})
    mock_writer = MockWriter()

    def generate_known_exp_curve(task, epochs, dataset_length, gamma):
        """Return {step: [value per metric]} following exp(gamma * step)."""
        # Force the losses to be slightly different in the case of two losses, to make sure we account
        # for expected differences between them
        multipliers = (0.9, 1.05) if task == "PRETRAINING" else (1, )
        return {
            step: [m * math.exp(gamma * step) for m in multipliers]
            for step in range(int(epochs * dataset_length))
        }

    step_losses = generate_known_exp_curve(task, epochs, dataset_length,
                                           exponent_loss)
    step_accuracies = generate_known_exp_curve(task, epochs, dataset_length,
                                               exponent_acc)

    def generate_step_result(step_num):
        """Return the (duration, hw_cycles, losses, accuracies) tuple for a step."""
        duration = random.random()
        hw_cycles = random.randint(1000, 2000)
        return duration, hw_cycles, step_losses[step_num], step_accuracies[step_num]

    def mock_stats_fn(loss, accuracy):
        # Identity: pass the known values straight through to the iteration.
        return loss, accuracy

    iteration = Iteration(args, batches_per_step, dataset_length, mock_writer,
                          recording_steps)
    iteration.stats_fn = mock_stats_fn

    for iteration.epoch in range(args.start_epoch, args.epochs):
        for _ in range(dataset_length):
            step_result = generate_step_result(iteration.count)
            iteration.add_stats(*step_result)
            iteration.count += 1

    if task == "PRETRAINING":
        loss_keys = ["loss/MLM", "loss/NSP"]
        acc_keys = ["accuracy/MLM", "accuracy/NSP"]
    else:
        loss_keys = ["loss"]
        acc_keys = ["accuracy"]

    def compare_against_true_curve(keys, true_values, writer):
        """Assert every recorded scalar matches the known curve at its step."""
        for key_index, key in enumerate(keys):
            recorded = writer.scalars[key]
            # Guard against a vacuous pass: with recording_steps == 1 the
            # iteration is expected to have logged values for every key.
            assert recorded, f"No scalars recorded for {key!r}"
            for step_num, value in recorded.items():
                assert np.isclose(true_values[step_num][key_index], value)

    compare_against_true_curve(loss_keys, step_losses, mock_writer)
    compare_against_true_curve(acc_keys, step_accuracies, mock_writer)