def bert_pretrained_initialisers(config, args):
    """Return the pretrained initialisers selected by ``args``, or ``None``.

    ``None`` is returned when ``args.synthetic_data`` or
    ``args.generated_data`` is set, when ``popdist_root(args)`` is falsy, or
    when neither ``args.onnx_checkpoint`` nor ``args.tf_checkpoint`` is set.

    If both checkpoints are given, both are loaded and the TF result is the
    one that is kept. ``config`` is only used when loading a TF checkpoint.
    """
    if args.synthetic_data:
        logger.info("Initialising from synthetic_data")
        return None

    if args.generated_data:
        logger.info("Initialising from generated_data")
        return None

    # The initialised weights will be broadcast after the session has been
    # created, so only the instance for which popdist_root(args) is truthy
    # loads them here.
    if not popdist_root(args):
        return None

    weights = None

    if args.onnx_checkpoint:
        logger.info(
            f"Initialising from ONNX checkpoint: {args.onnx_checkpoint}")
        weights = utils.load_initializers_from_onnx(args.onnx_checkpoint)

    if args.tf_checkpoint:
        logger.info(f"Initialising from TF checkpoint: {args.tf_checkpoint}")
        # NOTE(review): the meaning of the positional ``True`` is defined by
        # load_initializers_from_tf, which is not visible here.
        weights = load_initializers_from_tf(
            args.tf_checkpoint, True, config, args.task)

    if weights is None:
        return None

    # Merge in the entries derived for phased execution from the loaded set.
    phased = get_phased_initializers_from_default(args, weights)
    weights.update(**phased)
    return weights
def bert_pretrained_initialisers(config, args):
    """Return the pretrained initialisers selected by ``args``, or ``None``.

    ``None`` is returned when ``args.synthetic_data`` or
    ``args.generated_data`` is set, or when neither ``args.onnx_checkpoint``
    nor ``args.tf_checkpoint`` is set.

    If both checkpoints are given, both are loaded and the TF result is the
    one that is kept. ``config`` is only used when loading a TF checkpoint.
    """
    if args.synthetic_data:
        logger.info("Initialising from synthetic_data")
        return None

    if args.generated_data:
        logger.info("Initialising from generated_data")
        return None

    weights = None

    if args.onnx_checkpoint:
        logger.info(
            f"Initialising from ONNX checkpoint: {args.onnx_checkpoint}")
        weights = utils.load_initializers_from_onnx(args.onnx_checkpoint)

    if args.tf_checkpoint:
        logger.info(f"Initialising from TF checkpoint: {args.tf_checkpoint}")
        # NOTE(review): the meaning of the positional ``True`` is defined by
        # load_initializers_from_tf, which is not visible here.
        weights = load_initializers_from_tf(
            args.tf_checkpoint, True, config, args.task)

    if weights is None:
        return None

    # Merge in the entries derived for phased execution from the loaded set.
    phased = get_phased_initializers_from_default(args, weights)
    weights.update(**phased)
    return weights
def bert_pretrained_initialisers(config, args):
    """Return initialisers loaded from a checkpoint named in ``args``.

    Nothing is loaded when ``args.synthetic_data`` is set. Otherwise an ONNX
    checkpoint takes precedence over a TF checkpoint; ``config`` is only used
    for the TF path. Returns ``None`` when no checkpoint applies.
    """
    if not args.synthetic_data:
        if args.onnx_checkpoint:
            logger.info(
                f"Initialising from ONNX checkpoint: {args.onnx_checkpoint}")
            return utils.load_initializers_from_onnx(args.onnx_checkpoint)

        if args.tf_checkpoint:
            logger.info(
                f"Initialising from TF checkpoint: {args.tf_checkpoint}")
            # NOTE(review): the meaning of the positional ``True`` is defined
            # by load_initializers_from_tf, which is not visible here.
            return load_initializers_from_tf(args.tf_checkpoint, True, config)

    return None
def transpose_checkpoint_embedding(checkpoint):
    """Rebuild a BERT model from ``checkpoint`` and save the resulting ONNX file.

    The checkpoint's initialisers are loaded, a ``bert.Bert`` model is built
    from them, and the model proto from its builder is written to
    ``pargs.output_dir`` under the checkpoint's base file name.

    Relies on the module-level ``config``, ``args`` and ``pargs``.
    NOTE(review): presumably rebuilding the graph is what transposes the
    embedding; that step is not visible in this function.
    """
    print(f"Loading checkpoint: {checkpoint}")
    weights = utils.load_initializers_from_onnx(checkpoint)

    model = bert.Bert(config,
                      initializers=weights,
                      pipeline=args.pipeline)

    # The fifth value returned by bert_add_inputs is not needed here.
    indices, positions, segments, masks, _ = bert.bert_add_inputs(args, model)
    bert.bert_logits_graph(model, indices, positions, segments, masks,
                           args.pipeline)

    # Convert the builder's serialised proto into an onnx model object.
    rebuilt = onnx.load_from_string(model.builder.getModelProto())

    destination = os.path.join(pargs.output_dir, os.path.basename(checkpoint))
    print(f"Saving to: {destination}")
    onnx.save(rebuilt, destination)
def test_weight_mapping(num_vocab_splits, task):
    """Check that pipelined weights map onto a phased-execution model.

    A pipelined model is built, its initialisers are round-tripped through an
    ONNX file and extended with the phased-execution entries. A phased model
    and an unsplit pipelined model are then built from those initialisers and
    compared with ``check_onnx_model``.
    """
    settings = dict(task=task,
                    vocab_length=1024,
                    micro_batch_size=1,
                    hidden_size=64,
                    attention_heads=2,
                    sequence_length=20,
                    mask_tokens=8,
                    popart_dtype="FLOAT",
                    num_layers=2,
                    no_mask=True,
                    no_dropout=True,
                    no_attn_dropout=True,
                    embedding_serialization_vocab_steps=num_vocab_splits,
                    inference=True)
    config = BertConfig(**settings)

    # Run pipelined BERT
    pipelined_proto = get_model_proto(config, mode=ExecutionMode.PIPELINE)

    # Extract weights by saving the proto and loading its initialisers back
    with tempfile.TemporaryDirectory() as scratch_dir:
        model_path = os.path.join(scratch_dir, "model.onnx")
        onnx.save(pipelined_proto, model_path)
        initializers = load_initializers_from_onnx(model_path)
        phased_extra = get_phased_initializers_from_default(
            config, initializers)
        initializers.update(**phased_extra)

    # Create the phased_execution version of the model (uses the split config)
    config_nosplit = config._replace(embedding_serialization_vocab_steps=1)
    phased_proto = get_model_proto(config,
                                   mode=ExecutionMode.PHASED,
                                   initializers=initializers)

    # Create a pipelined version of the model without any embedding split
    # for the comparison
    pipelined_proto_nosplit = get_model_proto(config_nosplit,
                                              mode=ExecutionMode.PIPELINE,
                                              initializers=initializers)

    # Check initial protos match for pipelined vs phased_execution model
    name_mapping = phased_to_default_mapping(config)
    weight_transform = phased_from_default_transform(config)
    check_onnx_model(pipelined_proto_nosplit,
                     phased_proto,
                     name_mapping,
                     weight_transform,
                     allow_missing=False)
def transpose_checkpoint_embedding(checkpoint):
    """Rebuild a BERT model from ``checkpoint`` and save the resulting ONNX file.

    The checkpoint's initialisers are loaded, a ``bert.Bert`` model is built
    from them on a builder with pinned opsets, and the model proto from that
    builder is written to ``pargs.output_dir`` under the checkpoint's base
    file name.

    Relies on the module-level ``config``, ``args`` and ``pargs``.
    NOTE(review): presumably rebuilding the graph is what transposes the
    embedding; that step is not visible in this function.
    """
    print(f"Loading checkpoint: {checkpoint}")
    weights = utils.load_initializers_from_onnx(checkpoint)

    # Opset versions passed to the builder that constructs the graph.
    opsets = {
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    }
    model = bert.Bert(config,
                      builder=popart.Builder(opsets=opsets),
                      initializers=weights,
                      execution_mode=args.execution_mode)

    # The fifth value returned by bert_add_inputs is not needed here.
    indices, positions, segments, masks, _ = bert.bert_add_inputs(args, model)
    bert.bert_logits_graph(model, indices, positions, segments, masks)

    # Convert the builder's serialised proto into an onnx model object.
    rebuilt = onnx.load_from_string(model.builder.getModelProto())

    destination = os.path.join(pargs.output_dir, os.path.basename(checkpoint))
    print(f"Saving to: {destination}")
    onnx.save(rebuilt, destination)
def get_initializers(proto):
    """Return the initialisers of ``proto`` via a temporary ONNX file.

    ``proto`` is saved as ``model.onnx`` in a temporary directory and passed
    to ``load_initializers_from_onnx``. The directory is removed on exit.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        model_path = os.path.join(scratch_dir, "model.onnx")
        onnx.save(proto, model_path)
        loaded = load_initializers_from_onnx(model_path)
    return loaded