def test_pretraining_fwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={"ai.onnx": 9,
                                     "ai.onnx.ml": 1,
                                     "ai.graphcore": 1})
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    fwd_graph(popart_model,
              torch_model,
              mapping=onnx_torch_mapping,
              transform=onnx_torch_tform)

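
# For orientation: `fwd_graph` is a shared test helper not shown in this
# file. Judging from its arguments, it copies the PyTorch weights into the
# PopART graph via `mapping`/`transform`, runs both models on identical
# random inputs, and compares the outputs. A minimal sketch of that final
# comparison step; the tolerance and helper name are assumptions:
import numpy as np


def _compare_outputs_sketch(popart_outputs, torch_outputs, margin=1e-6):
    # Compare each PopART output against the matching torch output.
    for popart_out, torch_out in zip(popart_outputs, torch_outputs):
        assert np.allclose(popart_out, torch_out.detach().numpy(),
                           atol=margin)
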
def test_pretraining_bwd(custom_ops, opt_type):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = popart_model.builder.aiGraphcore.l1loss(
            [logits[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
              mapping={},
              transform=onnx_torch_tform,
              opt_type=opt_type)

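
# How `opt_type` reaches this test is defined by the suite's parametrisation,
# which is not shown here. A minimal sketch assuming pytest parametrisation;
# "SGD" is an assumed value, while "LAMB" is referenced above (`split_qkv` is
# only enabled for LAMB, which keeps separate optimiser state per Q/K/V):
import pytest


@pytest.mark.parametrize("opt_type", ["SGD", "LAMB"])
def _bwd_all_optimisers_sketch(custom_ops, opt_type):
    test_pretraining_bwd(custom_ops, opt_type)
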
def test_pretraining_bwd(custom_ops):
    #  ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        projection_serialization_steps=4,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        update_embedding_dict=False)
    popart_model = Bert(config, builder=builder)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = builder.aiGraphcore.l1loss([logits[0]],
                                          l1_lambda,
                                          debugPrefix="l1LossVal",
                                          reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
              mapping=onnx_torch_mapping,
              transform=onnx_torch_tform)

def test_pretraining_fwd(custom_ops, mode, replication_factor,
                         replicated_tensor_sharding):
    #  ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_cls_layer=False,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=False)
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        no_cls_layer=config.no_cls_layer))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform=onnx_torch_tform,
              replication_factor=replication_factor,
              replicated_tensor_sharding=replicated_tensor_sharding)

def pretraining_bwd(custom_ops, mode, replication_factor,
                    replicated_weight_sharding, opt_type,
                    vocab_length=9728, hidden_size=768):
    #  ------------------- PopART --------------------
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers to ensure
        # the MLM and embedding layers are in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1
    config = BertConfig(task="PRETRAINING",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    #  ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        if mode == ExecutionMode.PHASED:
            # In phased execution the loss is placed via the MLM scope
            # provider rather than an explicit virtualGraph call.
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.mlm_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [logits[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [logits[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
              mapping={},
              transform=onnx_torch_tform,
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding,
              opt_type=opt_type)

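
# Numerically, the two loss definitions above agree: PopART's `l1loss` with
# `ReductionType.Sum` computes l1_lambda * sum(|logits|), and
# `torch.norm(logits[0], 1)` is that same sum of absolute values. A tiny
# reference implementation (illustrative only) for checking by hand:
import numpy as np


def _reference_l1_loss(logits: np.ndarray, l1_lambda: float = 0.1) -> float:
    # lambda-scaled L1 norm, i.e. Sum reduction over all elements
    return l1_lambda * np.abs(logits).sum()
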
def test_load_from_chkpt(config_path, chkpt_path, custom_ops):
    """
    Compare the TF checkpoint loaded into our PopART model against the same
    checkpoint loaded into the modified PyTorch model:
        - Load tf weights into BERT using the torch impl -> run fwd model
        - Load tf weights into BERT using the popart impl -> run fwd model
        - Compare output tensors
    """
    config = load_bert_config_tf(config_path)

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })

    # Load Torch version
    torch_model = TorchModel(
        TorchBertConfig(
            config.vocab_length,
            config.hidden_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.attention_heads,
            intermediate_size=config.ff_size,
            hidden_act="relu",
            max_position_embeddings=config.max_positional_length,
            layer_norm_eps=config.layer_norm_eps,
            mask_tokens=config.mask_tokens,
        ))
    torch_model.eval()
    torch_model = load_tf_weights_in_bert(torch_model, config, chkpt_path)

    # Load Popart model
    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)

    popart_model, proto, output = load_from_tf(chkpt_path,
                                               True,
                                               config,
                                               indices,
                                               positions,
                                               builder=builder)

    # Run the models
    popart_inputs = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size *
                           config.sequence_length)).astype(np.int32),
        positions:
        np.random.randint(
            0,
            config.sequence_length,
            (config.batch_size * config.sequence_length),
        ).astype(np.int32),
    }

    torch_inputs = {
        "input_ids":
        popart_inputs[indices].reshape(config.batch_size,
                                       config.sequence_length),
        "position_ids":
        popart_inputs[positions].reshape(config.batch_size,
                                         config.sequence_length),
    }

    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) + 1,
    )

    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")
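
# `check_tensors` comes from the shared test utilities. A minimal stand-in to
# show the comparison it performs; the tolerances are assumptions and the
# real helper may differ:
def _check_tensors_sketch(expected_outputs, actual_outputs, rtol=1e-5,
                          atol=1e-6):
    # Element-wise comparison of every (expected, actual) output pair.
    for expected, actual in zip(expected_outputs, actual_outputs):
        np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)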