import numpy as np
import pytest
import torch
import popart

# NOTE (assumed context): the repository-local names used below -- Bert,
# BertConfig, ExecutionMode, get_model, TorchBertConfig,
# BertForNextSentencePrediction, fwd_graph, bwd_graph, NSP_MAPPING and
# NSP_TRANSFORM -- come from this repository's model and test helper modules;
# their import paths are not reproduced here.


def test_nsp_fwd(custom_ops):
    # ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        mask_tokens=0,
                        split_qkv=False)
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model, torch_model, NSP_MAPPING, transform=NSP_TRANSFORM)

# Variant of test_nsp_fwd that constructs the popart.Builder explicitly and
# passes it to Bert.
def test_nsp_fwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(
        opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1})
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)

# Variant of test_nsp_bwd that uses the legacy popart.L1Loss API and inlines
# the NSP weight mapping and transform dictionaries.
def test_nsp_bwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"])
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
        loss = popart.L1Loss(outputs[0], "l1Loss", 0.1)
        loss.virtualGraph(popart_model.nsp_scope.virtualGraph)
        return [loss]

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping={
                  "bert.pooler.dense.weight": "NSP/PoolW",
                  "bert.pooler.dense.bias": "NSP/PoolB",
                  "cls.seq_relationship.weight": "NSP/NspW",
                  "cls.seq_relationship.bias": "NSP/NspB"
              },
              transform={
                  "bert.pooler.dense.weight": np.transpose,
                  "cls.seq_relationship.weight": np.transpose
              })

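# The dictionaries inlined in the bwd_graph call above appear to be the same
# mapping that the other tests refer to as NSP_MAPPING and NSP_TRANSFORM. A
# minimal sketch of those module-level constants, assuming the flat layout
# used here (the phased-execution helper further below indexes NSP_MAPPING by
# ExecutionMode, so the real constants may be nested one level deeper):
#
#     NSP_MAPPING = {
#         "bert.pooler.dense.weight": "NSP/PoolW",
#         "bert.pooler.dense.bias": "NSP/PoolB",
#         "cls.seq_relationship.weight": "NSP/NspW",
#         "cls.seq_relationship.bias": "NSP/NspB",
#     }
#
#     NSP_TRANSFORM = {
#         "bert.pooler.dense.weight": np.transpose,
#         "cls.seq_relationship.weight": np.transpose,
#     }
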
def test_nsp_bwd(custom_ops, opt_type):
    # ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=2432,
                        micro_batch_size=1,
                        hidden_size=288,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        loss = popart_model.builder.aiGraphcore.l1loss(
            [outputs[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM,
              opt_type=opt_type)

def test_nsp_bwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
        loss = builder.aiGraphcore.l1loss(
            [outputs[0]],
            0.1,
            debugPrefix="l1Loss",
            reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)

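# The three builder-based tests above construct popart.Builder with the same
# opset versions. A small helper (hypothetical name, a sketch only) could
# factor that repetition out:
def make_builder():
    # Opset versions match those used by the builder-based tests above.
    return popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
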
def nsp_bwd(custom_ops, mode, opt_type, vocab_length=9728, hidden_size=768):
    if mode == ExecutionMode.PHASED:
        # Phased execution requires at least two transformer layers to ensure
        # mlm and embedding are in the same virtual graph.
        num_layers = 2
    else:
        num_layers = 1

    # ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=vocab_length,
                        num_layers=num_layers,
                        batch_size=1,
                        hidden_size=hidden_size,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        phased_execution_type="single",
                        execution_mode=mode,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = get_model(config, mode)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        if mode == ExecutionMode.PHASED:
            with popart_model.scope_provider(popart_model.builder,
                                             popart_model.nsp_scope):
                loss = popart_model.builder.aiGraphcore.l1loss(
                    [outputs[0]],
                    l1_lambda,
                    debugPrefix="l1LossVal",
                    reduction=popart.ReductionType.Sum)
        else:
            loss = popart_model.builder.aiGraphcore.l1loss(
                [outputs[0]],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
            popart_model.builder.virtualGraph(
                loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              mode,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING[mode],
              transform=NSP_TRANSFORM,
              opt_type=opt_type)

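# A plausible pytest entry point for the helper above -- a sketch only, not
# necessarily this repository's actual wrapper. The function name,
# ExecutionMode.DEFAULT and the "SGD" optimiser name are assumptions; only
# ExecutionMode.PHASED and "LAMB" appear in nsp_bwd itself.
@pytest.mark.parametrize("mode", [ExecutionMode.DEFAULT, ExecutionMode.PHASED])
@pytest.mark.parametrize("opt_type", ["SGD", "LAMB"])
def test_nsp_bwd_parametrized(custom_ops, mode, opt_type):
    nsp_bwd(custom_ops, mode, opt_type)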