def test_squad_fwd():
    """Compare the PopART SQUAD inference graph against the PyTorch reference.

    Builds a small 2-layer BERT in both frameworks with matching
    hyperparameters and delegates the actual forward-output comparison to
    ``fwd_graph``.
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1,
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=[],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    # Torch parameter name -> PopART tensor name for the CLS / SQUAD heads.
    weight_mapping = {
        "cls.transform.dense.weight": "CLS/LMPredictionW",
        "cls.transform.dense.bias": "CLS/LMPredictionB",
        "cls.transform.LayerNorm.weight": "CLS/Gamma",
        "cls.transform.LayerNorm.bias": "CLS/Beta",
        "qa_outputs.weight": "Squad/SquadW",
        "qa_outputs.bias": "Squad/SquadB",
    }
    # Dense weights are stored transposed relative to torch's layout.
    weight_transform = {
        "cls.transform.dense.weight": np.transpose,
        "qa_outputs.weight": np.transpose,
    }
    fwd_graph(popart_model,
              torch_model,
              mapping=weight_mapping,
              transform=weight_transform)
def test_squad_fwd(mode, replication_factor, replicated_weight_sharding):
    """Compare the PopART SQUAD inference graph against the PyTorch reference.

    Parameterised over execution mode, replication factor and replicated
    weight sharding (presumably pytest fixtures/parameters — confirm against
    the test harness).
    """
    split_qkv = False

    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=split_qkv,
                        squad_single_output=False)
    popart_model = get_model(config, mode)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    # qa_outputs stores its weight transposed relative to torch's layout;
    # the rest of the name mapping is mode-specific and kept centrally.
    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding)
def test_squad_fwd(custom_ops):
    """Compare the PopART SQUAD inference graph against the PyTorch reference.

    Uses a deliberately tiny model (hidden 64, seq 20) so the comparison is
    fast; ``custom_ops`` is presumably a fixture that loads the custom-op
    library — confirm against the test harness.
    """
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        split_qkv=False,
                        squad_single_output=False)
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=2,
                        num_labels=2))

    # qa_outputs stores its weight transposed relative to torch's layout.
    fwd_graph(popart_model,
              torch_model,
              mapping=ONNX_TORCH_MAPPING,
              transform={"qa_outputs.weight": np.transpose})
def test_squad_bwd():
    """Compare the PopART SQUAD backward pass against the PyTorch reference.

    Builds matching 1-layer models, attaches an L1 loss to each SQUAD output
    (start / end logits) on both sides, and delegates the gradient
    comparison to ``bwd_graph``.
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1,
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=[],
                        update_embedding_dict=False)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    # L1 coefficient shared by the PopART and torch losses.
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        # One L1 loss per SQUAD output, pinned to the SQUAD head's
        # virtual graph.
        losses = [
            popart.L1Loss(outputs[0], "startsLossVal", l1_lambda),
            popart.L1Loss(outputs[1], "endsLossVal", l1_lambda),
        ]
        for loss in losses:
            loss.virtualGraph(popart_model.squad_scope.virtualGraph)
        return losses

    def torch_loss_fn(outputs):
        # Matching torch-side loss: sum of l1_lambda * ||output||_1.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    # Torch parameter name -> PopART tensor name for the CLS / SQUAD heads;
    # dense weights are stored transposed relative to torch's layout.
    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping={
                  "cls.transform.dense.weight": "CLS/LMPredictionW",
                  "cls.transform.dense.bias": "CLS/LMPredictionB",
                  "cls.transform.LayerNorm.weight": "CLS/Gamma",
                  "cls.transform.LayerNorm.bias": "CLS/Beta",
                  "qa_outputs.weight": "Squad/SquadW",
                  "qa_outputs.bias": "Squad/SquadB",
              },
              transform={
                  "cls.transform.dense.weight": np.transpose,
                  "qa_outputs.weight": np.transpose,
              })
def squad_bwd(mode, replication_factor, replicated_weight_sharding, opt_type,
              vocab_length=9728, hidden_size=768):
    """Compare the PopART SQUAD backward pass against the PyTorch reference.

    Parameterised over execution mode, replication, replicated weight
    sharding and optimiser type; LAMB additionally exercises split QKV.
    Attaches a summed L1 loss to the start/end logits on both sides and
    delegates the gradient comparison to ``bwd_graph``.
    """
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=vocab_length,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))

    # L1 coefficient shared by the PopART and torch losses.
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        builder = popart_model.builder

        def l1(output, name):
            # Sum-reduced L1 loss over a single output tensor.
            return builder.aiGraphcore.l1loss(
                [output], l1_lambda,
                debugPrefix=name,
                reduction=popart.ReductionType.Sum)

        if mode == ExecutionMode.PHASED:
            # In phased execution the loss ops must be created inside the
            # SQUAD scope rather than placed via explicit virtual graphs.
            with popart_model.scope_provider(builder,
                                             popart_model.squad_scope):
                losses = [l1(outputs[0], "startsLossVal"),
                          l1(outputs[1], "endsLossVal")]
                final_loss = builder.aiOnnx.sum(losses,
                                                debugPrefix="finalLoss")
        else:
            losses = [l1(outputs[0], "startsLossVal"),
                      l1(outputs[1], "endsLossVal")]
            # Pin each loss, and their sum, onto the SQUAD head's
            # virtual graph.
            for loss in losses:
                builder.virtualGraph(loss,
                                     popart_model.squad_scope.virtualGraph)
            final_loss = builder.aiOnnx.sum(losses, debugPrefix="finalLoss")
            builder.virtualGraph(final_loss,
                                 popart_model.squad_scope.virtualGraph)
        return final_loss

    def torch_loss_fn(outputs):
        # Matching torch-side loss: sum of l1_lambda * ||output||_1.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    # qa_outputs stores its weight transposed relative to torch's layout;
    # the rest of the name mapping is mode-specific and kept centrally.
    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding,
              opt_type=opt_type)
def test_squad_bwd(custom_ops, replication_factor, replicated_tensor_sharding,
                   opt_type):
    """Compare the PopART SQUAD backward pass against the PyTorch reference.

    Uses a deliberately tiny 2-layer model (hidden 64, seq 20) and is
    parameterised over replication, replicated tensor sharding and optimiser
    type; LAMB additionally exercises split QKV. ``custom_ops`` is
    presumably a fixture that loads the custom-op library — confirm against
    the test harness.
    """
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        num_layers=2,
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_mask=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=2,
                        update_embedding_dict=True,
                        num_labels=2))

    # L1 coefficient shared by the PopART and torch losses.
    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        builder = popart_model.builder
        # Sum-reduced L1 loss per SQUAD output (start / end logits).
        losses = [
            builder.aiGraphcore.l1loss([outputs[0]], l1_lambda,
                                       debugContext="startsLossVal",
                                       reduction=popart.ReductionType.Sum),
            builder.aiGraphcore.l1loss([outputs[1]], l1_lambda,
                                       debugContext="endsLossVal",
                                       reduction=popart.ReductionType.Sum),
        ]
        # Pin each loss, and their sum, onto the SQUAD head's virtual graph.
        for loss in losses:
            builder.virtualGraph(loss, popart_model.squad_scope.virtualGraph)
        final_loss = builder.aiOnnx.sum(losses, debugContext="finalLoss")
        builder.virtualGraph(final_loss,
                             popart_model.squad_scope.virtualGraph)
        return final_loss

    def torch_loss_fn(outputs):
        # Matching torch-side loss: sum of l1_lambda * ||output||_1.
        per_output = [l1_lambda * torch.norm(out, 1) for out in outputs]
        return torch.add(*per_output)

    # qa_outputs stores its weight transposed relative to torch's layout.
    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=ONNX_TORCH_MAPPING,
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_tensor_sharding=replicated_tensor_sharding,
              opt_type=opt_type)