def validate_checkpoints(self_args, args):
    """Validate every checkpoint found recursively under ``args.checkpoint_dir``.

    Seeds the libraries, builds the BERT config, gathers checkpoint paths
    matching ``self_args.model_search_string``, runs the validations across
    ``self_args.num_processes`` processes, and dumps the aggregate results to
    ``validation_result.json``.

    Raises:
        FileNotFoundError: if no checkpoint matches the search string.

    Returns:
        The results object produced by ``perform_validations``.
    """
    set_library_seeds(args.seed)
    config = bert_config_from_args(args)

    # list(...) instead of a copy-comprehension: rglob already yields Paths.
    checkpoint_paths = list(
        Path(args.checkpoint_dir).rglob(self_args.model_search_string))
    if not checkpoint_paths:
        raise FileNotFoundError(
            f"Did not recursively find any checkpoints at path: {args.checkpoint_dir}")

    # Load an initial model to configure the IO tensors and the session
    args.onnx_checkpoint = checkpoint_paths[0]
    initializers = bert_pretrained_initialisers(config, args)

    results = perform_validations(self_args.num_processes, checkpoint_paths,
                                  args, config, initializers,
                                  self_args.available_ipus)

    # NOTE(review): checkpoints are searched under args.checkpoint_dir but the
    # result file is written under self_args.checkpoint_dir — confirm the two
    # are intended to differ before changing either.
    with open(os.path.join(self_args.checkpoint_dir, "validation_result.json"), 'w') as fh:
        json.dump(results, fh, indent=4)

    return results
def finetune_checkpoints(self_args, args):
    """Fine-tune from each checkpoint in ``self_args.checkpoint_dir``.

    Collects the checkpoints matching ``self_args.model_search_string`` in
    sorted order, prepares the output directory, and hands the whole list to
    ``training_run``.

    Raises:
        FileNotFoundError: if no checkpoint matches the search string.
    """
    set_library_seeds(args.seed)
    config = bert_config_from_args(args)

    search_pattern = os.path.join(self_args.checkpoint_dir,
                                  self_args.model_search_string)
    found_checkpoints = sorted(glob.glob(search_pattern))

    output_dir = os.path.join(self_args.checkpoint_dir, "squad_output")
    os.makedirs(output_dir, exist_ok=True)

    if not found_checkpoints:
        raise FileNotFoundError(
            f"Did not find any checkpoints at path: {args.checkpoint_dir}")

    # Load an initial model to configure the IO tensors and the session
    args.onnx_checkpoint = found_checkpoints[0]
    initializers = bert_pretrained_initialisers(config, args)

    training_run(args, config, initializers, found_checkpoints)
    logger.info("Fine-Tuning Complete")
def run_embedding_layer(args):
    """Build only the BERT embedding sub-graph and (optionally) run inference on it.

    Constructs a Bert model, wires its embedding output as the sole logit, and
    when ``args.inference`` is set runs a single batch through a popart
    inference session, returning the collected results. Returns None when
    ``args.inference`` is not set.
    """
    set_library_seeds(args.seed)
    config = bert_config_from_args(args)
    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    # TODO: Change slice to opset10
    model = Bert(config,
                 builder=popart.Builder(opsets={
                     "ai.onnx": 9,
                     "ai.onnx.ml": 1,
                     "ai.graphcore": 1
                 }),
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions will have the matrices instead of the index vector.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    # Only the embedding output is exposed — no encoder/task head is built here.
    logits = tuple([model.embedding(indices, positions, segments)])

    if args.inference:
        outputs = bert_add_logit_outputs(model, logits)
        writer = None

        dataset = get_bert_dataset(
            model, args, [indices, positions, segments, masks, labels])

        data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

        iteration = Iteration(
            args,
            steps_per_epoch=len(dataset),
            writer=writer,
            recording_steps=args.aggregate_metrics_over_steps)

        request_ipus = bert_required_ipus(args, model)
        device = acquire_device(args, request_ipus)
        session, anchors = bert_inference_session(model, args, data_flow, device)

        logger.info("Inference Started")
        inputs = [indices, positions, segments, *masks]
        # NOTE(review): the line below is a bare string literal, i.e. dead code —
        # the full infer loop is deliberately bypassed in favour of the manual
        # single-step loop that follows.
        """bert_infer_loop(args, session, dataset,
                        inputs, logits, anchors, iteration)"""

        # Results are only kept for real SQuAD data, not synthetic/generated runs.
        save_results = args.task == "SQUAD" and not (args.synthetic_data or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)

        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors, start_times,
                                            end_times, dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)
        for data in dataset:
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data,
                                             anchors, logits, iteration,
                                             start_times, end_times, stepio)

            if save_results:
                output.append(result)
            # Unconditional break: this routine intentionally processes only
            # the first batch (embedding-layer smoke run, not full inference).
            break

        device.detach()
        return output

    return None
indices, positions, segments, masks, _ = bert.bert_add_inputs(args, model) bert.bert_logits_graph(model, indices, positions, segments, masks) proto = model.builder.getModelProto() onnx_proto = onnx.load_from_string(proto) output_path = os.path.join(pargs.output_dir, os.path.basename(checkpoint)) print(f"Saving to: {output_path}") onnx.save(onnx_proto, output_path) if __name__ == "__main__": pargs, rem = parse_args() if "--config" not in rem: print( "Please specify the target config using the '--config' parameter (e.g. --config configs/squad_base_384.json" ) sys.exit(1) args = utils.parse_bert_args(rem) config = bert.bert_config_from_args(args) os.makedirs(pargs.output_dir, exist_ok=True) checkpoint_files = glob.glob( os.path.join(pargs.checkpoint_dir, pargs.model_search_string)) pool = multiprocessing.Pool(pargs.num_processes) pool.map(transpose_checkpoint_embedding, checkpoint_files)