def test_split_qkv_weight_loading():
    config = BertConfig(task="SQUAD",
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        split_qkv=False)

    def get_split(full_t, t):
        return np.split(full_t, 3, axis=1)["QKV".index(t)]

    mapping = {f"Layer{i}/Attention/{t}": f"Layer{i}/Attention/QKV"
               for i in range(config.num_layers) for t in "QKV"}
    transform = {f"Layer{i}/Attention/{t}": partial(get_split, t=t)
                 for i in range(config.num_layers) for t in "QKV"}

    # Get an unsplit checkpoint
    np.random.seed(123)
    proto_1 = get_model_proto(config)
    initializers = get_initializers(proto_1)

    split_config = config._replace(split_qkv=True)

    # Load the unsplit checkpoint into a split model
    np.random.seed(456)
    proto_2 = get_model_proto(split_config, initializers=initializers)
    check_onnx_model(proto_1, proto_2, mapping, transform, allow_missing=False)

    # Extract weights
    initializers = get_initializers(proto_2)

    # Load the split checkpoint into an unsplit model
    np.random.seed(456)
    proto_3 = get_model_proto(config, initializers=initializers)
    check_onnx_model(proto_3, proto_2, mapping, transform, allow_missing=False)

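# Hedged illustration of the QKV split used by get_split above: a fused
# attention weight of shape (hidden, 3 * hidden) holds Q, K and V side by
# side, and np.split along axis 1 recovers each part. The shapes here are
# illustrative only.
import numpy as np

hidden = 4
fused_qkv = np.arange(hidden * 3 * hidden).reshape(hidden, 3 * hidden)
q, k, v = np.split(fused_qkv, 3, axis=1)
assert q.shape == k.shape == v.shape == (hidden, hidden)
# "QKV".index("Q") == 0 selects the Q slice, matching get_split's indexing.
assert np.array_equal(q, np.split(fused_qkv, 3, axis=1)["QKV".index("Q")])
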
def test_simplified_position_encoding(position_length, hidden_size):
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=hidden_size,
                        max_positional_length=position_length,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        positional_embedding_init_fn="SIMPLIFIED",
                        inference=True)
    popart_model = Bert(config, builder=builder)

    shape = (config.max_positional_length, config.hidden_size)
    pa_data = popart_model.generate_simplified_periodic_pos_data(
        config.dtype, shape)

    bb_data = simplified_generator(position_length, hidden_size)
    assert np.all(np.abs(bb_data - pa_data) < 1e-8)

def test_pretraining_fwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(
        opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1})
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    fwd_graph(popart_model,
              torch_model,
              mapping=onnx_torch_mapping,
              transform=onnx_torch_tform)

def test_positional_encoding_data(position_length, hidden_size):
    if not tf.executing_eagerly():
        tf.enable_eager_execution()
    assert tf.executing_eagerly()

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=hidden_size,
                        max_positional_length=position_length,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        positional_embedding_init_fn="TRANSFORMER",
                        inference=True)
    popart_model = Bert(config, builder=builder)

    shape = (config.max_positional_length, config.hidden_size)
    pos_pa = popart_model.generate_transformer_periodic_pos_data(
        config.dtype, shape)

    pos_tf = get_position_encoding_tf(shape[0], shape[1]).numpy()

    # TensorFlow broadcast multiplication seems to produce slightly different
    # results to numpy, hence the higher-than-expected error. The embeddings
    # do correlate well between the two despite this.
    assert np.all(np.abs(pos_tf - pos_pa) < 5e-5)

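# Hedged numpy reference for the standard transformer sinusoidal encoding the
# test above compares against (Vaswani et al., 2017):
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d))
# This is a sketch assuming an even depth; the exact sin/cos layout used by
# get_position_encoding_tf may differ.
import numpy as np

def reference_position_encoding(num_positions, depth):
    positions = np.arange(num_positions)[:, np.newaxis]     # (P, 1)
    dims = np.arange(0, depth, 2)[np.newaxis, :]            # (1, D/2)
    angles = positions / np.power(10000, dims / depth)      # (P, D/2)
    encoding = np.zeros((num_positions, depth))
    encoding[:, 0::2] = np.sin(angles)
    encoding[:, 1::2] = np.cos(angles)
    return encoding
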
def load_bert_config_tf(config_path, override_vocab=None):
    """Load the BERT config data from Google Research's checkpoint format
    into the PopART BERT config format."""
    import json
    with open(config_path, "r") as fh:
        config_data = json.load(fh)

    config = BertConfig(
        vocab_length=config_data["vocab_size"] if override_vocab is None else override_vocab,
        hidden_size=config_data["hidden_size"],
        sequence_length=config_data["max_position_embeddings"],
        max_positional_length=config_data["max_position_embeddings"],
        ff_size__=config_data["intermediate_size"],
        attention_heads=config_data["num_attention_heads"],
        num_layers=config_data["num_hidden_layers"],
        # TODO: Read the rest of these in from a GC config?
        projection_serialization_steps=4,
        batch_size=1,
        popart_dtype="FLOAT",
        no_dropout=True,
        inference=True,
        activation_type="relu",
        custom_ops=["gather", "attention"]
    )
    return config

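# Hedged usage sketch for load_bert_config_tf: the keys below are the
# standard fields of Google Research's bert_config.json; the path and values
# are illustrative only.
import json
import tempfile

bert_config_json = {
    "vocab_size": 30522,
    "hidden_size": 768,
    "max_position_embeddings": 512,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as fh:
    json.dump(bert_config_json, fh)
    config_path = fh.name

config = load_bert_config_tf(config_path)
assert config.hidden_size == 768
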
def test_nsp_fwd(custom_ops):
    # ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        mask_tokens=0,
                        split_qkv=False)
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model, torch_model, NSP_MAPPING, transform=NSP_TRANSFORM)

def create_dataset(args):
    # A simple copy of main bert.py up to the dataset creation
    config = BertConfig()
    model = Bert(config, builder=popart.Builder())
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    inputs = [indices, positions, segments, masks, labels]
    embedding_dict, positional_dict = model.get_model_embeddings()

    shapeOf = model.builder.getTensorShape
    # Chain the mask/label inputs (which may be lists) onto the first three
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    dataset = get_bert_dataset(tensor_shapes,
                               input_file=args.input_files,
                               output_dir=args.output_dir,
                               sequence_length=args.sequence_length,
                               vocab_file=args.vocab_file,
                               vocab_length=args.vocab_length,
                               batch_size=args.batch_size,
                               batches_per_step=args.batches_per_step,
                               embedding_dict=embedding_dict,
                               positional_dict=positional_dict,
                               generated_data=args.generated_data,
                               is_training=False,
                               no_drop_remainder=True,
                               shuffle=args.shuffle,
                               mpi_size=args.mpi_size,
                               is_distributed=(args.mpi_size > 1))
    return dataset

def test_nsp_fwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(
        opsets={"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1})
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)

def test_split_embedding(custom_ops, weight_transposed, phase, num_splits):
    """Test serialised embedding.

    Args:
        custom_ops: Custom op module.
        weight_transposed (bool): If True, weights are constructed transposed
            for the embedding layer.
        phase (str): Forward pass or backward pass.
        num_splits (int): Number of serialisation steps for the embedding.
    """
    np.random.seed(1984)
    config = BertConfig(vocab_length=4864,
                        micro_batch_size=1,
                        hidden_size=4096,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        embedding_serialization_vocab_steps=num_splits)

    data, outputs, proto, post_proto = popart_result_and_model(
        config, weight_transposed, is_bwd=(phase == 'bwd'))

    inputs = [
        t.reshape(config.micro_batch_size,
                  config.sequence_length).astype(np.int32) for t in data
    ]

    torch_output, torch_model = pytorch_result_and_model(
        config, inputs, proto, weight_transposed, is_bwd=(phase == 'bwd'))

    check_tensors(torch_output, outputs)

    if phase == 'bwd':
        initializers = get_initializers(post_proto, weight_transposed)
        for name, weight in torch_model.named_parameters():
            check_tensors(weight.data.numpy(), initializers[name])

def test_weight_mapping(num_vocab_splits, task):
    config = BertConfig(task=task,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        mask_tokens=8,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        no_dropout=True,
                        no_attn_dropout=True,
                        embedding_serialization_vocab_steps=num_vocab_splits,
                        inference=True)

    # Run pipelined BERT
    pipelined_proto = get_model_proto(config, mode=ExecutionMode.PIPELINE)

    # Extract weights
    with tempfile.TemporaryDirectory() as tmp:
        file_path = os.path.join(tmp, "model.onnx")
        onnx.save(pipelined_proto, file_path)
        initializers = load_initializers_from_onnx(file_path)
        initializers.update(
            **get_phased_initializers_from_default(config, initializers))

    # Create a phased_execution version of the model
    config_nosplit = config._replace(embedding_serialization_vocab_steps=1)
    phased_proto = get_model_proto(config,
                                   mode=ExecutionMode.PHASED,
                                   initializers=initializers)

    # Create a pipelined version of the model without any embedding split
    # for the comparison
    pipelined_proto_nosplit = get_model_proto(config_nosplit,
                                              mode=ExecutionMode.PIPELINE,
                                              initializers=initializers)

    # Check initial protos match for the pipelined vs phased_execution model
    check_onnx_model(pipelined_proto_nosplit,
                     phased_proto,
                     phased_to_default_mapping(config),
                     phased_from_default_transform(config),
                     allow_missing=False)

def test_trainable_params():
    config = BertConfig(task="PRETRAINING",
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        mask_tokens=8,
                        popart_dtype="FLOAT",
                        num_layers=2,
                        no_mask=True,
                        no_dropout=True,
                        no_attn_dropout=True,
                        embedding_serialization_vocab_steps=4,
                        inference=False)

    # Create a phased_execution version of the model
    model = get_model(config, ExecutionMode.PHASED)

    data = {
        'indices': np.random.randint(
            0, config.vocab_length,
            (config.micro_batch_size * config.sequence_length)).astype(np.uint32),
        'positions': np.random.randint(
            0, config.sequence_length,
            (config.micro_batch_size * config.sequence_length)).astype(np.uint32),
        'segments': np.random.randint(
            0, 2,
            (config.micro_batch_size * config.sequence_length)).astype(np.uint32)
    }

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = model.builder.addInputTensor(sequence_info)
    positions = model.builder.addInputTensor(sequence_info)
    segments = model.builder.addInputTensor(sequence_info)

    data_popart = {}
    data_popart[indices] = data['indices']
    data_popart[segments] = data['segments']
    data_popart[positions] = data['positions']

    model(indices, positions, segments)
    proto = model.builder.getModelProto()

    # Extract weights from the onnx model and check the initializer count
    # matches the number of tensors tracked in self.tensors[0]
    with tempfile.TemporaryDirectory() as tmp:
        model_path = os.path.join(tmp, "model.onnx")
        onnx.save(proto, model_path)
        onnx_model = onnx.load(model_path)

    assert len(model.tensors[0]) == len(onnx_model.graph.initializer)

def test_pretraining_bwd(custom_ops, opt_type):
    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        update_embedding_dict=True,
                        no_cls_layer=True,
                        no_mask=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        update_embedding_dict=True,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = popart_model.builder.aiGraphcore.l1loss(
            [logits[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping={},
        transform=onnx_torch_tform,
        opt_type=opt_type)

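# Hedged sanity check of the loss equivalence relied on above: an L1 loss
# with Sum reduction scaled by l1_lambda is the same quantity as
# l1_lambda * ||x||_1 in PyTorch. Pure-torch sketch; the values are
# illustrative.
import torch

l1_lambda = 0.1
logits = torch.tensor([[1.5, -2.0], [0.25, -0.5]])
# torch.norm(t, 1) flattens and sums absolute values, i.e. the L1 norm.
assert torch.isclose(l1_lambda * torch.norm(logits, 1),
                     l1_lambda * logits.abs().sum())
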
def test_nsp_bwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=["gather", "attention"])
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
        loss = popart.L1Loss(outputs[0], "l1Loss", 0.1)
        loss.virtualGraph(popart_model.nsp_scope.virtualGraph)
        return [loss]

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping={
                  "bert.pooler.dense.weight": "NSP/PoolW",
                  "bert.pooler.dense.bias": "NSP/PoolB",
                  "cls.seq_relationship.weight": "NSP/NspW",
                  "cls.seq_relationship.bias": "NSP/NspB"
              },
              transform={
                  "bert.pooler.dense.weight": np.transpose,
                  "cls.seq_relationship.weight": np.transpose
              })

def _load_edge_model(self, bert_model_file, bert_config_file):
    bert_config = BertConfig.from_json_file(bert_config_file)
    model = BertEdgeScorer(bert_config)
    model_states = torch.load(bert_model_file)
    print(model_states.keys())
    model.bert.load_state_dict(model_states)
    model.cuda()
    model.eval()
    return model

def create_dataset(args):
    # A simple copy of main bert.py up to the dataset creation
    config = BertConfig()
    model = Bert(config, builder=popart.Builder())
    indices, positions, segments, masks, labels = bert_add_inputs(args, model)
    inputs = [indices, positions, segments, masks, labels]
    embedding_dict, positional_dict = model.get_model_embeddings()

    shapeOf = model.builder.getTensorShape
    inputs = reduce(chain, inputs[3:], inputs[:3])
    tensor_shapes = [(tensorId, shapeOf(tensorId)) for tensorId in inputs]

    dataset = get_bert_dataset(args, tensor_shapes)
    return dataset

def test_nsp_bwd(custom_ops, opt_type):
    # ------------------- PopART --------------------
    config = BertConfig(task="NSP",
                        vocab_length=2432,
                        micro_batch_size=1,
                        hidden_size=288,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_mask=True,
                        update_embedding_dict=True,
                        split_qkv=(opt_type == "LAMB"))
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        update_embedding_dict=True,
                        num_labels=2))

    l1_lambda = 0.1

    def popart_loss_fn(outputs):
        loss = popart_model.builder.aiGraphcore.l1loss(
            [outputs[0]],
            l1_lambda,
            debugContext="l1LossVal",
            reduction=popart.ReductionType.Sum)
        popart_model.builder.virtualGraph(loss,
                                          popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return l1_lambda * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM,
              opt_type=opt_type)

def test_activation_function(mode, phase, momentum, micro_batch_size,
                             batch_serialization_factor):
    set_library_seeds(0)

    popart_act_function, pytorch_activation = ACTIVATIONS["Gelu"]
    config = BertConfig(vocab_length=128,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        activation_type=str(popart_act_function))

    data, outputs, proto, post_proto = popart_result_and_model(
        config,
        mode,
        batch_serialization_factor,
        is_bwd=(phase == 'bwd'),
        momentum=momentum)

    inputs = [
        data.reshape(config.micro_batch_size, config.sequence_length,
                     config.hidden_size)
    ]

    # ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(config.vocab_length,
                                   config.hidden_size,
                                   config.num_layers,
                                   config.attention_heads,
                                   layer_norm_eps=config.layer_norm_eps,
                                   hidden_dropout_prob=0.,
                                   hidden_act=pytorch_activation)

    torch_output, torch_model = pytorch_result_and_model(
        torch_config,
        inputs,
        proto,
        mode,
        is_bwd=(phase == 'bwd'),
        momentum=momentum)

    check_tensors(torch_output, outputs, margin=7e-6)

    if phase == 'bwd':
        check_model(torch_model,
                    post_proto,
                    TORCH_TO_ONNX[mode],
                    transform=TRANSPOSE_WEIGHTS,
                    margin=7e-6)

def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file,
                                     pytorch_dump_path):
    """Convert a BERT model checkpoint from TensorFlow to PyTorch."""
    # Initialise the PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from the TF checkpoint
    load_tf_weights_in_bert(model, tf_checkpoint_path)

    # Save the PyTorch model
    print("Saving PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)

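# Hedged CLI wiring for the converter above, modelled on the usual
# checkpoint-conversion scripts; the flag names are illustrative assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a TensorFlow BERT checkpoint to PyTorch.")
    parser.add_argument("--tf_checkpoint_path", type=str, required=True,
                        help="Path to the TensorFlow checkpoint.")
    parser.add_argument("--bert_config_file", type=str, required=True,
                        help="Path to the bert_config.json of the model.")
    parser.add_argument("--pytorch_dump_path", type=str, required=True,
                        help="Where to write the converted PyTorch weights.")
    cli_args = parser.parse_args()

    convert_tf_checkpoint_to_pytorch(cli_args.tf_checkpoint_path,
                                     cli_args.bert_config_file,
                                     cli_args.pytorch_dump_path)
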
def test_pretraining_bwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        projection_serialization_steps=4,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        update_embedding_dict=False)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens))

    l1_lambda = 0.1

    def popart_loss_fn(logits):
        loss = builder.aiGraphcore.l1loss(
            [logits[0]],
            l1_lambda,
            debugPrefix="l1LossVal",
            reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.mlm_scope.virtualGraph)
        return loss

    bwd_graph(
        popart_model,
        torch_model,
        popart_loss_fn=popart_loss_fn,
        torch_loss_fn=lambda logits: l1_lambda * torch.norm(logits[0], 1),
        mapping=onnx_torch_mapping,
        transform=onnx_torch_tform)

def test_nsp_bwd(custom_ops):
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="NSP",
                        vocab_length=9728,
                        num_layers=1,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForNextSentencePrediction(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act=config.activation_type,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    def popart_loss_fn(outputs):
        loss = builder.aiGraphcore.l1loss(
            [outputs[0]],
            0.1,
            debugPrefix="l1Loss",
            reduction=popart.ReductionType.Sum)
        builder.virtualGraph(loss, popart_model.nsp_scope.virtualGraph)
        return loss

    def torch_loss_fn(outputs):
        return 0.1 * torch.norm(outputs[0], 1)

    bwd_graph(popart_model,
              torch_model,
              popart_loss_fn=popart_loss_fn,
              torch_loss_fn=torch_loss_fn,
              mapping=NSP_MAPPING,
              transform=NSP_TRANSFORM)

def test_squad_fwd():
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=[],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mapping={
                  "cls.transform.dense.weight": "CLS/LMPredictionW",
                  "cls.transform.dense.bias": "CLS/LMPredictionB",
                  "cls.transform.LayerNorm.weight": "CLS/Gamma",
                  "cls.transform.LayerNorm.bias": "CLS/Beta",
                  "qa_outputs.weight": "Squad/SquadW",
                  "qa_outputs.bias": "Squad/SquadB"
              },
              transform={
                  "cls.transform.dense.weight": np.transpose,
                  "qa_outputs.weight": np.transpose
              })

def test_weight_decay(weight_decay):
    lr = 0.01
    l1_lambda = 0.1

    # ------------------- PopART -------------------------
    config = BertConfig(vocab_length=128,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=[],
                        activation_type='Gelu')

    data, outputs, proto, post_proto = popart_result_and_model(
        config, weight_decay=weight_decay, lr=lr, l1_lambda=l1_lambda)

    # ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(config.vocab_length,
                                   config.hidden_size,
                                   config.num_layers,
                                   config.attention_heads,
                                   layer_norm_eps=config.layer_norm_eps,
                                   hidden_dropout_prob=0.,
                                   hidden_act=nn.functional.gelu)

    inputs = [
        data.reshape(config.batch_size, config.sequence_length,
                     config.hidden_size)
    ]

    torch_output, torch_model = pytorch_result_and_model(
        torch_config,
        inputs,
        proto,
        weight_decay=weight_decay,
        lr=lr,
        l1_lambda=l1_lambda)

    # ------------------- Check outputs -------------------------
    check_tensors(torch_output, outputs)
    check_model(torch_model,
                post_proto,
                TORCH_TO_ONNX,
                transform=TRANSPOSE_WEIGHTS)

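# Hedged numpy sketch of the SGD-with-weight-decay update this test exercises,
# assuming the classic L2-style formulation where the decay term is added to
# the gradient before the step: w <- w - lr * (g + weight_decay * w).
# Whether PopART applies decay exactly this way is an assumption here.
import numpy as np

def sgd_step_with_weight_decay(w, grad, lr=0.01, weight_decay=0.1):
    # The decay term pulls each weight towards zero in proportion to its size.
    return w - lr * (grad + weight_decay * w)

w = np.array([1.0, -2.0])
g = np.array([0.5, 0.5])
print(sgd_step_with_weight_decay(w, g))  # -> [0.994, -2.003]
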
def test_pretraining_fwd(custom_ops, mode, replication_factor,
                         replicated_tensor_sharding):
    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        mask_tokens=2,
                        popart_dtype="FLOAT",
                        activation_type="relu",
                        no_dropout=True,
                        no_attn_dropout=True,
                        no_cls_layer=False,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=False)
    popart_model = get_model(config, mode)

    # ------------------- PyTorch -------------------------
    torch_model = BertForMaskedLM(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        no_cls_layer=config.no_cls_layer))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform=onnx_torch_tform,
              replication_factor=replication_factor,
              replicated_tensor_sharding=replicated_tensor_sharding)

def test_squad_fwd(mode, replication_factor, replicated_weight_sharding):
    split_qkv = False
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        num_layers=2,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        execution_mode=mode,
                        split_qkv=split_qkv,
                        squad_single_output=False)
    popart_model = get_model(config, mode)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=config.mask_tokens,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mode,
              mapping=ONNX_TORCH_MAPPING[mode],
              transform={"qa_outputs.weight": np.transpose},
              replication_factor=replication_factor,
              replicated_weight_sharding=replicated_weight_sharding)

def load_bert_config_tf(config_path):
    """Load the BERT config data from Google Research's checkpoint format
    into the PopART BERT config format."""
    with open(config_path, "r") as fh:
        config_data = json.load(fh)

    config = BertConfig(
        vocab_length=config_data["vocab_size"],
        hidden_size=config_data["hidden_size"],
        sequence_length=config_data["max_position_embeddings"],
        ff_size__=config_data["intermediate_size"],
        attention_heads=config_data["num_attention_heads"],
        num_layers=config_data["num_hidden_layers"],
        # TODO: Read the rest of these in from a GC config?
        popart_dtype="FLOAT",
        no_dropout=True,
    )
    return config

def load_model(model_path, device):
    """
    Load the pretrained model states and prepare the model for sentiment
    analysis.

    Parameters
    ----------
    model_path: str
        Path to the pretrained model states binary file.
    device: torch.device
        Device to load the model on.

    Returns
    -------
    model: BertForSequenceClassification
        Model with the loaded pretrained states.
    """
    config = BertConfig(vocab_size=30522, type_vocab_size=2)
    model = BertForSequenceClassification(config, 2, [11])
    model_states = torch.load(model_path, map_location=device)
    model.load_state_dict(model_states)
    model.eval()
    return model

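# Hedged usage sketch for load_model; the weights filename is illustrative.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model("sentiment_model.bin", device)
# The model is now in eval mode and ready for inference on `device`.
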
def test_squad_fwd(custom_ops):
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        encoder_start_ipu=1,
                        vocab_length=1024,
                        micro_batch_size=1,
                        hidden_size=64,
                        attention_heads=2,
                        sequence_length=20,
                        max_positional_length=20,
                        activation_type="relu",
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        no_mask=True,
                        split_qkv=False,
                        squad_single_output=False)
    popart_model = Bert(config)

    # ------------------- PyTorch -------------------------
    torch_model = BertForQuestionAnswering(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        num_hidden_layers=config.num_layers,
                        num_attention_heads=config.attention_heads,
                        intermediate_size=config.ff_size,
                        hidden_act="relu",
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        mask_tokens=2,
                        num_labels=2))

    fwd_graph(popart_model,
              torch_model,
              mapping=ONNX_TORCH_MAPPING,
              transform={"qa_outputs.weight": np.transpose})

def test_activation_function(activation_function, phase, custom_ops):
    popart_act_function, pytorch_activation = ACTIVATIONS[activation_function]
    config = BertConfig(vocab_length=128,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=[],
                        activation_type=str(popart_act_function))

    data, outputs, proto, post_proto = popart_result_and_model(
        config, is_bwd=(phase == 'bwd'))

    inputs = [
        data.reshape(config.batch_size, config.sequence_length,
                     config.hidden_size)
    ]

    # ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(config.vocab_length,
                                   config.hidden_size,
                                   config.num_layers,
                                   config.attention_heads,
                                   layer_norm_eps=config.layer_norm_eps,
                                   hidden_dropout_prob=0.,
                                   hidden_act=pytorch_activation)

    torch_output, torch_model = pytorch_result_and_model(
        torch_config, inputs, proto, is_bwd=(phase == 'bwd'))

    check_tensors(torch_output, outputs)

    if phase == 'bwd':
        check_model(torch_model,
                    post_proto,
                    TORCH_TO_ONNX,
                    transform=TRANSPOSE_WEIGHTS)

def load_deprecated_model(model_path):
    """
    Load the pretrained model states and prepare the model for sentiment
    analysis on CPU.

    This method returns a custom BertForSequenceClassification model that
    allows it to work with LayerIntegratedGradients and
    LayerIntermediateGradients.

    Parameters
    ----------
    model_path: str
        Path to the pretrained model states binary file.

    Returns
    -------
    model: BertForSequenceClassification
        Model with the loaded pretrained states.
    """
    config = BertConfig(vocab_size=30522, type_vocab_size=2)
    model = BertForSequenceClassification(config, 2, [11])
    model_states = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(model_states)
    model.eval()
    return model

def bert_config_from_args(args):
    return BertConfig(**{
        k: getattr(args, k)
        for k in BertConfig._fields if hasattr(args, k)
    })

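# Hedged usage sketch for bert_config_from_args: BertConfig is a namedtuple
# (it exposes _fields and _replace elsewhere in this code), so any parsed
# argument whose name matches a config field is picked up and the rest are
# ignored. The Namespace fields below are illustrative.
import argparse

args = argparse.Namespace(hidden_size=64,
                          vocab_length=1024,
                          unrelated_flag=True)
config = bert_config_from_args(args)
assert config.hidden_size == 64 and config.vocab_length == 1024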