def run_embedding_layer(args):
    """Build a BERT model with only the embedding layer and run one inference step.

    The library seeds are set from ``args.seed``, a model is built from the
    configuration derived from ``args``, and the output of
    ``model.embedding`` is used in place of the model logits. Nothing is
    executed on a device unless ``args.inference`` is truthy.

    Args:
        args: Run options. This function reads ``seed``, ``execution_mode``,
            ``inference``, ``task``, ``synthetic_data``, ``generated_data``,
            ``low_latency_inference`` and ``aggregate_metrics_over_steps``
            directly; the helpers it calls receive ``args`` as a whole.

    Returns:
        ``None`` when ``args.inference`` is falsy. Otherwise a list: it holds
        the result of the single processed step when ``args.task`` is
        ``"SQUAD"`` and neither synthetic nor generated data is in use, and
        is empty in every other case.
    """
    set_library_seeds(args.seed)

    config = bert_config_from_args(args)
    initializers = bert_pretrained_initialisers(config, args)

    logger.info("Building Model")
    # Specifying ai.onnx opset9 for the slice syntax
    # TODO: Change slice to opset10
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    model = Bert(config,
                 builder=builder,
                 initializers=initializers,
                 execution_mode=args.execution_mode)

    # If config.host_embedding is enabled, indices and positions will have
    # the matrices instead of the index vector.
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    logits = (model.embedding(indices, positions, segments),)

    if not args.inference:
        return None

    outputs = bert_add_logit_outputs(model, logits)
    dataset = get_bert_dataset(
        model, args, [indices, positions, segments, masks, labels])
    data_flow = popart.DataFlow(dataset.batches_per_step, outputs)

    iteration = Iteration(
        args,
        steps_per_epoch=len(dataset),
        writer=None,
        recording_steps=args.aggregate_metrics_over_steps)

    request_ipus = bert_required_ipus(args, model)
    device = acquire_device(args, request_ipus)

    # Everything that uses the device sits inside try/finally so that the
    # device is detached even when session creation or inference raises.
    try:
        session, anchors = bert_inference_session(
            model, args, data_flow, device)
        logger.info("Inference Started")

        inputs = [indices, positions, segments, *masks]

        # NOTE(review): the original carried the call
        #   bert_infer_loop(args, session, dataset, inputs, logits, anchors,
        #                   iteration)
        # here as a bare string; the code below presumably replaces that
        # loop, limited to a single step - confirm.
        save_results = args.task == "SQUAD" and not (
            args.synthetic_data or args.generated_data)

        start_times = defaultdict(list)
        end_times = defaultdict(list)

        # Create the stepio once outside of the inference loop:
        static_data = {}
        if args.low_latency_inference and args.task == "SQUAD":
            stepio = create_callback_stepio(static_data, anchors,
                                            start_times, end_times,
                                            dataset.batches_per_step)
        else:
            stepio = None

        output = []
        logger.info(dataset)

        for data in dataset:
            static_data.update({t: data[t] for t in inputs})
            result = bert_process_infer_data(args, session, static_data,
                                             anchors, logits, iteration,
                                             start_times, end_times, stepio)
            if save_results:
                output.append(result)
            # Only the first step of the dataset is processed.
            break
    finally:
        device.detach()

    return output
def test_embedding_fwd(custom_ops):
    """Run the embedding forward pass in PopART and in PyTorch and compare.

    Random word, position and segment ids are fed to the PopART embedding;
    the PopART weights are then copied into a PyTorch ``BertEmbeddings``
    module, which is run on the same ids. Both outputs are handed to
    ``check_tensors`` with ``margin=5e-7``.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available - confirm.
    """
    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config)
    builder = popart_model.builder

    # The three id inputs are flat vectors of micro_batch * sequence entries.
    flat_len = config.micro_batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [flat_len])
    input_ids = [builder.addInputTensor(sequence_info) for _ in range(3)]
    indices, positions, segments = input_ids

    # Exclusive upper bound of the random ids, in the order
    # (indices, positions, segments).
    id_limits = (config.vocab_length, config.max_positional_length, 2)
    data = {
        tensor_id: np.random.randint(0, limit, flat_len).astype(np.uint32)
        for tensor_id, limit in zip(input_ids, id_limits)
    }

    with builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()
    outputs, _ = run_py(proto,
                        data,
                        output,
                        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.micro_batch_size, config.sequence_length)
    inputs = [
        data[tensor_id].reshape(batch_shape).astype(np.int32)
        for tensor_id in input_ids
    ]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = BertEmbeddings(torch_config)
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
def test_embedding_bwd(custom_ops):
    """Train the embedding for a few SGD steps in PopART and PyTorch and compare.

    Both frameworks use an L1 loss scaled by 0.1 on the embedding output and
    plain SGD with learning rate 0.01 for five repetitions on the same
    random ids. The final outputs go to ``check_tensors`` and the updated
    weights to ``check_model``, both with a margin of 7e-6.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available - confirm.
    """
    learning_rate = 0.01
    l1_lambda = 0.1
    num_reps = 5
    margin = 7e-6

    #  ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        update_embedding_dict=True)
    popart_model = Bert(config)
    builder = popart_model.builder

    # The three id inputs are flat vectors of micro_batch * sequence entries.
    flat_len = config.micro_batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [flat_len])
    input_ids = [builder.addInputTensor(sequence_info) for _ in range(3)]
    indices, positions, segments = input_ids

    # Exclusive upper bound of the random ids, in the order
    # (indices, positions, segments).
    id_limits = (config.vocab_length, config.max_positional_length, 2)
    data = {
        tensor_id: np.random.randint(0, limit, flat_len).astype(np.uint32)
        for tensor_id, limit in zip(input_ids, id_limits)
    }

    optimizer = popart.ConstSGD(learning_rate)

    with builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)

    l1 = builder.aiGraphcore.l1loss([output],
                                    l1_lambda,
                                    debugContext="l1LossVal",
                                    reduction=popart.ReductionType.Sum)

    proto = builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 ipus=1,
                                 loss=l1,
                                 num_reps=num_reps,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.micro_batch_size, config.sequence_length)
    inputs = [
        data[tensor_id].reshape(batch_shape).astype(np.int32)
        for tensor_id in input_ids
    ]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
        update_embedding_dict=config.update_embedding_dict)
    torch_model = BertEmbeddings(torch_config)
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})

    optim = torch.optim.SGD(torch_model.parameters(), learning_rate)

    # The ids never change between repetitions, so convert them once.
    torch_inputs = [torch.from_numpy(array).long() for array in inputs]
    for _ in range(num_reps):
        torch_output = torch_model(*torch_inputs)
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs, margin=margin)

    check_model(torch_model, post_proto, TORCH_TO_ONNX, {}, margin=margin)
def test_embedding_fwd(custom_ops):
    """Run the embedding forward pass (``custom_ops=['gather']``) in PopART and PyTorch.

    Random word, position and segment ids are fed to the PopART embedding;
    the PopART weights are copied into a PyTorch ``BertEmbeddings`` module
    through an explicit name mapping, with ``np.transpose`` registered for
    the word and position embedding weights. Both outputs are handed to
    ``check_tensors``.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available - confirm.
    """
    #  ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    # The three id inputs are flat vectors of batch * sequence entries.
    flat_len = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [flat_len])
    input_ids = [builder.addInputTensor(sequence_info) for _ in range(3)]
    indices, positions, segments = input_ids

    # Exclusive upper bound of the random ids, in the order
    # (indices, positions, segments).
    id_limits = (config.vocab_length, config.max_positional_length, 2)
    data = {
        tensor_id: np.random.randint(0, limit, flat_len).astype(np.uint32)
        for tensor_id, limit in zip(input_ids, id_limits)
    }

    # Use the custom embedding for layout
    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()
    outputs, _ = run_py(proto,
                        data,
                        output,
                        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.batch_size, config.sequence_length)
    inputs = [
        data[tensor_id].reshape(batch_shape).astype(np.int32)
        for tensor_id in input_ids
    ]

    # PyTorch parameter name -> name used in the ONNX model.
    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    # Weights that get np.transpose applied when they are copied.
    transposed_weights = dict.fromkeys(
        ("word_embeddings.weight", "position_embeddings.weight"),
        np.transpose)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = BertEmbeddings(torch_config)
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding_bwd(custom_ops):
    """Take one SGD step on the embedding in PopART and PyTorch and compare.

    Both frameworks use an L1 loss scaled by 0.1 on the embedding output and
    SGD with learning rate 0.01 on the same random ids. The outputs go to
    ``check_tensors`` and the updated weights to ``check_model``; weights are
    matched through an explicit name mapping, with ``np.transpose``
    registered for the word and position embedding weights.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available - confirm.
    """
    l1_lambda = 0.1
    learning_rate = 0.01

    #  ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'])
    popart_model = Bert(config, builder=builder)

    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    # The three id inputs are flat vectors of batch * sequence entries.
    flat_len = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [flat_len])
    input_ids = [builder.addInputTensor(sequence_info) for _ in range(3)]
    indices, positions, segments = input_ids

    # Exclusive upper bound of the random ids, in the order
    # (indices, positions, segments).
    id_limits = (config.vocab_length, config.max_positional_length, 2)
    data = {
        tensor_id: np.random.randint(0, limit, flat_len).astype(np.uint32)
        for tensor_id, limit in zip(input_ids, id_limits)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(learning_rate)

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=l1,
        optimizer=optimizer,
        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.batch_size, config.sequence_length)
    inputs = [
        data[tensor_id].reshape(batch_shape).astype(np.int32)
        for tensor_id in input_ids
    ]

    # PyTorch parameter name -> name used in the ONNX model.
    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    # Weights that get np.transpose applied when they are copied or checked.
    transposed_weights = dict.fromkeys(
        ("word_embeddings.weight", "position_embeddings.weight"),
        np.transpose)

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = BertEmbeddings(torch_config)
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=transposed_weights)

    optim = torch.optim.SGD(torch_model.parameters(),
                            learning_rate,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_inputs = [torch.from_numpy(array).long() for array in inputs]
    torch_output = torch_model(*torch_inputs)
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    torch_outputs = [torch_output.detach().numpy()]

    check_tensors(torch_outputs, outputs)

    check_model(torch_model,
                post_proto,
                torch_to_onnx,
                transform=transposed_weights)
def test_embedding_fwd(custom_ops):
    """Run the embedding forward pass in PopART and in PyTorch and compare.

    Random word, position and segment ids are fed to a PopART embedding built
    on an explicit opset-9 builder; the PopART weights are then copied into a
    PyTorch ``BertEmbeddings`` module using ``TORCH_TO_ONNX``, and both
    outputs are handed to ``check_tensors``.

    Args:
        custom_ops: Not referenced in the body; presumably a pytest fixture
            that makes the custom ops available - confirm.
    """
    #  ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    # The three id inputs are flat vectors of batch * sequence entries.
    flat_len = config.batch_size * config.sequence_length
    sequence_info = popart.TensorInfo("UINT32", [flat_len])
    input_ids = [builder.addInputTensor(sequence_info) for _ in range(3)]
    indices, positions, segments = input_ids

    # Exclusive upper bound of the random ids, in the order
    # (indices, positions, segments).
    id_limits = (config.vocab_length, config.max_positional_length, 2)
    data = {
        tensor_id: np.random.randint(0, limit, flat_len).astype(np.uint32)
        for tensor_id, limit in zip(input_ids, id_limits)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()
    outputs, _ = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    batch_shape = (config.batch_size, config.sequence_length)
    inputs = [
        data[tensor_id].reshape(batch_shape).astype(np.int32)
        for tensor_id in input_ids
    ]

    #  ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps)
    torch_model = BertEmbeddings(torch_config)
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)