def training_run(bert_args, config, initializers, checkpoint_paths):
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=bert_args.execution_mode)

    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)
    predictions, probs = bert_infer_graph(model, logits)
    losses = bert_loss_graph(model, probs, labels)
    outputs = bert_add_validation_outputs(model, predictions, losses)

    embedding_dict, positional_dict = model.get_model_embeddings()

    dataset = get_bert_dataset(model, bert_args,
                               [indices, positions, segments, masks, labels],
                               embedding_dict, positional_dict)

    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    request_ipus, _ = calc_required_ipus(bert_args, model)
    device = acquire_device(bert_args, request_ipus)

    logger.info(f"Dataset length: {len(dataset)}")
    writer = bert_writer(bert_args)

    iteration = Iteration(
        bert_args,
        batches_per_step=dataset.batches_per_step,
        steps_per_epoch=len(dataset),
        writer=writer,
        recording_steps=bert_args.aggregate_metrics_over_steps)
    optimizer_factory = ScheduledOptimizerFactory(bert_args, iteration, "SGD",
                                                  model.tensors)
    session, anchors = bert_training_session(model, bert_args, data_flow,
                                             losses, device, optimizer_factory)

    for path in checkpoint_paths:
        ckpt_name = os.path.splitext(os.path.basename(path))[0]
        session.resetHostWeights(os.path.abspath(path))
        session.weightsFromHost()

        logger.info(f"Fine-tuning started for checkpoint: {path}")
        run_fine_tuning_store_ckpt(bert_args, model, ckpt_name, session,
                                   dataset, predictions, losses, labels,
                                   anchors)

    device.detach()
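

# Illustrative usage sketch (not part of the original module): driving
# training_run over a directory of pretrained checkpoints. The helpers
# bert_config_from_args and bert_pretrained_initialisers are the same ones used
# by run_embedding_layer below; the checkpoint-directory attribute and the
# ".onnx" suffix are assumptions about how checkpoints are stored on disk.
def _example_training_run(args):
    import glob
    config = bert_config_from_args(args)
    initializers = bert_pretrained_initialisers(config, args)
    # Assumed layout: one ONNX checkpoint file per pretraining run.
    checkpoint_paths = sorted(
        glob.glob(os.path.join(args.checkpoint_dir, "*.onnx")))
    training_run(args, config, initializers, checkpoint_paths)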


def run_embedding_layer(args):
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)

    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specify ai.onnx opset 9 for the slice syntax.
    # TODO: Change slice to opset 10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions contain the
    # embedding matrices instead of index vectors.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    # Only the embedding layer is built; its output stands in for the logits.
    logits = tuple([model.embedding(indices, positions, segments)])

    if args.inference:
        outputs = bert_add_logit_outputs(model, logits)
        writer = None

        dataset = get_bert_dataset(
            model, args, [indices, positions, segments, masks, labels])

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            args,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=args.aggregate_metrics_over_steps)

        request_ipus = bert_required_ipus(args, model)

        device = acquire_device(args, request_ipus)

        session, anchors = bert_inference_session(model, args, data_flow,
                                                  device)

        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        # bert_infer_loop(args, session, dataset, inputs, logits, anchors, iteration)

        save_results = args.task == "SQUAD" and not (args.synthetic_data
                                                     or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)

        # Create the stepio once, outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times,
                                            end_times,
                                            dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)
        for data in dataset:
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data,
                                             anchors, logits, iteration,
                                             start_times, end_times, stepio)
            if save_results:
                output.append(result)
            # Process a single batch only.
            break

        device.detach()
        return output

    return None
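

# Illustrative usage sketch (not part of the original module): exercising the
# embedding-only graph on one batch. Setting args.inference and args.task here
# mirrors the attributes read inside run_embedding_layer; how args is built
# (e.g. by the project's own CLI parser) is assumed to happen elsewhere.
def _example_embedding_layer(args):
    args.inference = True   # run_embedding_layer only executes its inference branch
    args.task = "SQUAD"     # results are kept unless synthetic/generated data is used
    results = run_embedding_layer(args)
    if results:
        logger.info(f"Embedding layer returned {len(results)} result(s)")
    return results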


def pooled_validation_run(bert_args,
                          config,
                          initializers,
                          checkpoint_paths,
                          num_processes=1,
                          available_ipus=16):
    logger.info("Building Model")
    model = Bert(config,
                 builder=popart.Builder(
                     opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}),
                 initializers=initializers)
    indices, positions, segments, masks, labels = bert_add_inputs(
        bert_args, model)
    logits = bert_logits_graph(model, indices, positions, segments, masks,
                               bert_args.execution_mode)

    inputs = [indices, positions, segments, *masks]
    outputs = bert_add_logit_outputs(model, logits)

    with tempfile.TemporaryDirectory() as temp_results_path:
        # Inject the checkpoint-specific SQuAD results directory into the dataset
        # args, otherwise the results overwrite each other when multithreaded.
        bert_args.squad_results_dir = temp_results_path

        dataset = get_bert_dataset(
            model, bert_args, [indices, positions, segments, masks, labels])
        logger.info(f"Dataset length: {len(dataset)}")

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            bert_args,
            batches_per_step=dataset.batches_per_step,
            steps_per_epoch=len(dataset),
            writer=None,
            recording_steps=bert_args.aggregate_metrics_over_steps)

        request_ipus, _ = calc_required_ipus(bert_args, model)

        if request_ipus * num_processes > available_ipus:
            raise ValueError(
                "Cannot run with requested number of processes - too many IPUs required")

        device = acquire_device(bert_args, request_ipus)

        session, anchors = bert_inference_session(
            model, bert_args, data_flow, device)

        model_results = recursive_defaultdict()
        for path in checkpoint_paths:
            session.resetHostWeights(str(path.absolute()))
            session.weightsFromHost()

            logger.info(f"Inference started for checkpoint: {path.absolute()}")
            result = run_inference_extract_result(bert_args, session, dataset,
                                                  inputs, logits, anchors,
                                                  iteration)

            result_into_recursive_path(model_results, path,
                                       bert_args.checkpoint_dir, result)

        device.detach()

    return model_results
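

# Illustrative usage sketch (not part of the original module): pooled validation
# over every ONNX checkpoint under bert_args.checkpoint_dir. pathlib.Path objects
# are used because pooled_validation_run calls path.absolute() on each entry; the
# ".onnx" glob and the size of the IPU pool are assumptions.
def _example_pooled_validation(bert_args, config, initializers):
    from pathlib import Path
    checkpoint_paths = sorted(Path(bert_args.checkpoint_dir).glob("*.onnx"))
    return pooled_validation_run(bert_args, config, initializers,
                                 checkpoint_paths,
                                 num_processes=1, available_ipus=16)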