def fwd_graph(popart_model, torch_model, mapping=None, transform=None):
    """Run the PopART forward graph and the equivalent PyTorch model on the
    same random inputs and check that their outputs match.

    Args:
        popart_model: PopART BERT wrapper exposing ``config`` and ``builder``.
        torch_model: Equivalent ``torch.nn.Module`` to compare against.
        mapping: Optional override/init for the torch->onnx name mapping.
        transform: Optional override/init for the weight transform mapping.
    """
    # ------------------- PopART --------------------
    config = popart_model.config
    builder = popart_model.builder

    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    # Random token/position/segment ids, flattened to [batch * seq_len].
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32),
        positions:
        np.random.randint(0, config.sequence_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32)
    }

    output = popart_model.build_graph(indices, positions, segments)
    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) +
        popart_model.layer_offset)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # Same data, reshaped to [batch, seq_len] for the torch model.
    inputs = {
        "input_ids":
        data[indices].reshape(config.batch_size, config.sequence_length),
        "position_ids":
        data[positions].reshape(config.batch_size, config.sequence_length),
        "token_type_ids":
        data[segments].reshape(config.batch_size, config.sequence_length)
    }

    torch_to_onnx = get_mapping(config, init=mapping)
    transform_weights = get_transform(config, init=transform)

    # ------------------- PyTorch -------------------------
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights)
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def pytorch_result_and_model(torch_config, inputs, popart_proto, is_bwd=False):
    """Build the torch reference model from a serialized PopART proto.

    Loads the proto, copies its weights into a fresh ``BertFCN``, runs a
    forward pass, and — when ``is_bwd`` is set — additionally performs a
    single SGD step against an L1-norm loss.

    Returns:
        Tuple of (model output(s), torch model).
    """
    onnx_model = onnx.load_model_from_string(popart_proto)
    model = BertFCN(torch_config)
    model.eval()  # dropout off: the comparison must be deterministic
    copy_weights_to_torch(model,
                          onnx_model,
                          TORCH_TO_ONNX,
                          transform=TRANSPOSE_WEIGHTS)

    outputs = run_fwd_model(inputs, model)
    if not is_bwd:
        return outputs, model

    # Single optimiser step against an L1-norm loss.
    l1_lambda = 0.1
    optimizer = torch.optim.SGD(model.parameters(),
                                0.01,
                                weight_decay=0.0,
                                momentum=0.0)
    tensors = [torch.from_numpy(t).float() for t in inputs]
    outputs = model(*tensors)[0]
    (l1_lambda * torch.norm(outputs, 1)).backward()
    optimizer.step()
    return outputs.detach().numpy(), model
def pytorch_result_and_model(torch_config, inputs, popart_proto,
                             weight_decay=0.0, lr=0.0, l1_lambda=0.0):
    """Reference PyTorch run: load PopART weights, take one SGD step.

    Builds a ``BertFCN`` from ``torch_config``, copies weights out of the
    serialized PopART model, runs a warm-up forward pass, then performs a
    single SGD update against an L1-norm loss. Weight decay is applied to
    every parameter except biases and LayerNorm parameters.

    Returns:
        Tuple of (detached output ndarray, torch model).
    """
    reference = BertFCN(torch_config)
    reference.eval()  # dropout off for a deterministic comparison
    onnx_model = onnx.load_model_from_string(popart_proto)
    copy_weights_to_torch(reference,
                          onnx_model,
                          TORCH_TO_ONNX,
                          transform=TRANSPOSE_WEIGHTS)
    run_fwd_model(inputs, reference)

    # Biases and LayerNorm parameters are exempt from weight decay.
    def _exempt(name):
        return "bias" in name or "LayerNorm" in name

    param_groups = [{
        'params':
        [p for n, p in reference.named_parameters() if _exempt(n)],
        'weight_decay': 0.
    }, {
        'params':
        [p for n, p in reference.named_parameters() if not _exempt(n)],
        'weight_decay': weight_decay
    }]
    optimizer = torch.optim.SGD(param_groups, lr, momentum=0.0)

    activations = reference(*[torch.from_numpy(t).float() for t in inputs])[0]
    loss = l1_lambda * torch.norm(activations, 1)
    loss.backward()
    optimizer.step()
    return activations.detach().numpy(), reference
def run_models(config, proto, indices, positions, segments, output,
               popart_model, torch_model):
    """Run the PopART and PyTorch models on identical random inputs and
    verify that their weights and forward outputs agree.

    Args:
        config: BertConfig used to size the random inputs.
        proto: Serialized PopART model proto.
        indices, positions, segments: PopART input tensor ids.
        output: PopART output tensor id(s) to fetch.
        popart_model: PopART model wrapper (provides ``total_ipus``).
        torch_model: Equivalent ``torch.nn.Module``.
    """
    onnx_proto = onnx.load_model_from_string(proto)
    # Weights must already match before either model is run.
    check_model(torch_model, onnx_proto, get_mapping(config),
                get_transform(config))

    # Run the models
    popart_inputs = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(
            0,
            config.sequence_length,
            (config.batch_size * config.sequence_length),
        ).astype(np.uint32),
        segments:
        np.random.randint(
            0,
            2,
            (config.batch_size * config.sequence_length),
        ).astype(np.uint32),
    }

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=popart_model.total_ipus,
    )

    # Same data, reshaped to [batch, seq_len] for the torch model.
    torch_inputs = {
        "input_ids":
        popart_inputs[indices].reshape(config.batch_size,
                                       config.sequence_length),
        "position_ids":
        popart_inputs[positions].reshape(config.batch_size,
                                         config.sequence_length),
        "token_type_ids":
        popart_inputs[segments].reshape(config.batch_size,
                                        config.sequence_length),
    }

    # Dropout off so the comparison is deterministic.
    torch_model.eval()
    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    # Weights must still match after the PopART run, and outputs must agree.
    check_model(torch_model, post_proto, get_mapping(config),
                get_transform(config))
    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")
def pytorch_result_and_model(config, inputs, popart_proto, weight_transposed,
                             is_bwd=False):
    """Run the reference embedding model with weights taken from PopART.

    Args:
        config (BertConfig): Popart config.
        inputs (np.ndarray): Input np array.
        popart_proto (onnx.proto): Onnx protobuf.
        weight_transposed (bool): If True, onnx weights are constructed
            transposed.
        is_bwd (bool, optional): True if bwd_pass. Defaults to False.

    Returns:
        Tuple: Output np.array and Torch model.
    """
    torch_config = TorchBertConfig(config.vocab_length,
                                   config.hidden_size,
                                   config.num_layers,
                                   config.attention_heads,
                                   layer_norm_eps=config.layer_norm_eps)
    embedding = nn.Embedding(torch_config.vocab_size,
                             torch_config.hidden_size,
                             padding_idx=0)
    embedding.eval()  # dropout off for the reference run

    # Pull the (possibly transposed) initializers out of the PopART proto
    # and load them into the torch parameters by name.
    onnx_model = onnx.load_model_from_string(popart_proto)
    initializers = get_initializers(onnx_model, weight_transposed)
    for name, parameter in embedding.named_parameters():
        parameter.data.copy_(torch.from_numpy(initializers[name]).float())

    outputs = run_fwd_model(inputs, embedding)
    if not is_bwd:
        return outputs, embedding

    # One SGD step against an L1 loss to exercise the backward pass.
    optimizer = torch.optim.SGD(embedding.parameters(),
                                0.01,
                                weight_decay=0.0,
                                momentum=0.0)
    activations = embedding(*[torch.from_numpy(t).long() for t in inputs])[0]
    (0.1 * torch.norm(activations, 1)).backward()
    optimizer.step()
    return [activations.detach().numpy()], embedding
def pytorch_result_and_model(torch_config, inputs, popart_proto, mode,
                             is_bwd=False, momentum=0.0):
    """Build the torch reference model from a PopART proto and run it.

    Runs a forward pass; when ``is_bwd`` is set it additionally performs
    ``num_reps_bwd`` SGD steps against an L1-norm loss.

    NOTE(review): ``lr`` and ``num_reps_bwd`` are not defined in this
    function — presumably module-level constants; confirm they exist in
    the enclosing module.

    Returns:
        Tuple of (output array(s), torch model).
    """
    # Conversion of the popart model to onnx
    proto = onnx.load_model_from_string(popart_proto)
    torch_model = BertFCN(torch_config)
    # Turn off dropout
    torch_model.eval()
    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX[mode],
                          transform=TRANSPOSE_WEIGHTS)

    result = run_fwd_model(inputs, torch_model)
    if is_bwd:
        l1_lambda = 0.1
        optim = torch.optim.SGD(torch_model.parameters(),
                                lr,
                                weight_decay=0.0,
                                momentum=momentum)
        if momentum > 0.0:
            # Pre-seed the optimiser state with zero buffers — presumably so
            # the first update matches the PopART run. SGD only reads
            # 'momentum_buffer'; the exp_avg/exp_avg_sq/step entries are
            # Adam-style state and look unused here — TODO confirm.
            for group in optim.param_groups:
                for p in group['params']:
                    optim.state[p]['momentum_buffer'] = p.data * 0.
                    optim.state[p]['exp_avg'] = p.data * 0.
                    optim.state[p]['exp_avg_sq'] = p.data * 0.
                    optim.state[p]['step'] = 0

        for _ in range(num_reps_bwd):
            result = torch_model(
                *[torch.from_numpy(t).float() for t in inputs])[0]
            torch_loss = l1_lambda * torch.norm(result, 1)
            torch_loss.backward()
            optim.step()
            optim.zero_grad()
        result = [result.detach().numpy()]

    return result, torch_model
def test_embedding_fwd(custom_ops):
    """Forward-pass parity test for the embedding layer: PopART vs PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    # Random token/position/segment ids, flattened to [batch * seq_len].
    data = {
        indices:
        np.random.randint(
            0, config.vocab_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        positions:
        np.random.randint(
            0, config.max_positional_length,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32),
        segments:
        np.random.randint(
            0, 2,
            (config.micro_batch_size * config.sequence_length)).astype(
                np.uint32)
    }

    user_options = {"enableStochasticRounding": True}
    with popart_model.builder.nameScope("Embedding"):
        output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    # torch embeddings take signed int ids shaped [batch, seq_len].
    inputs = [
        data[t].reshape(config.micro_batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()
    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
def test_embedding_fwd(custom_ops):
    """Forward-pass parity test for the custom-gather embedding layer:
    PopART vs PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    # Random token/position/segment ids, flattened to [batch * seq_len].
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    # Use the custom embedding for layout
    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    # torch embeddings take signed int ids shaped [batch, seq_len].
    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    # Map torch parameter names to the ONNX initializer names.
    torch_to_onnx = {
        "word_embeddings.weight": "Embedding_Dict",
        "position_embeddings.weight": "Positional_Dict",
        "token_type_embeddings.weight": "Segment_Dict",
        "LayerNorm.weight": "Gamma",
        "LayerNorm.bias": "Beta"
    }

    # Embedding dictionaries are stored transposed on the PopART side.
    transposed_weights = {
        "word_embeddings.weight": np.transpose,
        "position_embeddings.weight": np.transpose,
    }

    # ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()
    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding_projection_fwd(custom_ops):
    """Forward-pass parity test for the embedding + projection (tied LM head)
    path: PopART vs PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        embedding_serialization_vocab_steps=4,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_cls_layer=False,
                        inference=True)
    popart_model = Bert(config, builder=builder)

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32)
    }

    # Build embedding lookup -> norm -> dropout -> CLS head -> projection.
    x = popart_model.gather(
        indices,
        config.vocab_length,
        "Embedding_Dict")
    x = popart_model.norm(x)
    x = popart_model.dropout(x)
    with popart_model.builder.nameScope("CLS"):
        x = popart_model.lm_prediction_head(x)
    output = popart_model.projection(x)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    # torch model takes signed int ids shaped [batch, seq_len].
    inputs = [data[indices].reshape(
        config.micro_batch_size, config.sequence_length).astype(np.int32)]

    # ------------------- PyTorch -------------------------
    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps,
                        no_cls_layer=config.no_cls_layer))
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX,
                          TRANSPOSE_WEIGHTS)
    # Re-tie the decoder weights to the embedding after the copy.
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding(config, phase):
    """Parity test for the embeddings layer: PyTorch reference vs the TF/IPU
    implementation, in either forward ("fwd") or backward ("bwd") phase.

    NOTE(review): the body reads a name ``test_config`` everywhere while the
    ``config`` parameter is never used — presumably ``test_config`` is a
    module-level/fixture object; confirm this is intended.
    """
    # define input
    indices = np.random.randint(
        0, test_config.vocab_size,
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    # Positions are simply 0..seq_len-1 for each batch row.
    positions = np.reshape(
        np.arange(test_config.sequence_length),
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    segments = np.random.randint(
        0, 2,
        (test_config.batch_size, test_config.sequence_length)).astype(np.int32)
    inputs = [d for d in [indices, positions, segments]]

    # build model
    # PyTorch model
    torch_config = TorchBertConfig(
        vocab_size_or_config_json_file=test_config.vocab_size,
        hidden_size=test_config.hidden_size,
        hidden_act=test_config.hidden_act,
        num_attention_heads=test_config.num_attention_heads,
        hidden_dropout_prob=test_config.hidden_dropout_prob,
        max_position_embeddings=test_config.max_position_embeddings,
        type_vocab_size=test_config.type_vocab_size,
        update_embedding_dict=True,
        layer_norm_eps=test_config.layer_norm_eps)
    torch_model = TorchBertEmbeddings(torch_config)
    torch_model.eval()

    # TF model
    tf_config = TFBertConfig(
        vocab_size=test_config.vocab_size,
        hidden_size=test_config.hidden_size,
        hidden_act=test_config.hidden_act,
        num_attention_heads=test_config.num_attention_heads,
        max_position_embeddings=test_config.max_position_embeddings,
        max_predictions_per_seq=test_config.max_predictions_per_seq,
        hidden_dropout_prob=test_config.hidden_dropout_prob,
        type_vocab_size=test_config.type_vocab_size,
        initializer_range=test_config.initializer_range,
        dtype=test_config.dtype,
        matmul_serialize_factor=test_config.matmul_serialize_factor,
        static_mask=False)

    # forward check
    if phase == "fwd":
        torch_outputs = run_fwd_model(inputs, torch_model)

        with tf.Graph().as_default():
            tf_model = TFBertModel(tf_config, is_training=True)
            with ops.device('cpu'):
                input_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                           dtype=tf.int32)
                position_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                              dtype=tf.int32)
                segment_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                             dtype=tf.int32)

            # Single-IPU configuration for the TF run.
            cfg = utils.create_ipu_config()
            cfg = utils.auto_select_ipus(cfg, 1)
            utils.configure_ipu_system(cfg)
            utils.move_variable_initialization_to_cpu()

            with ops.device("/device:IPU:0"):
                opt = ipu_compiler.compile(
                    tf_model.embeddings_layer,
                    inputs=[input_ids, position_ids, segment_ids])

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                # copy pytorch weight to tf
                var_and_init = copy_torch_weights_to_tf(
                    torch_model, tf_model, TF_TO_TORCH, {}, sess)
                sess.run(var_and_init)
                # run tf feed forward
                tf_outputs = sess.run(
                    opt, {
                        input_ids: indices,
                        position_ids: positions,
                        segment_ids: segments
                    })
                # compare tf output with pytorch output
                check_tensors(tf_outputs, torch_outputs, margin=1.5e-8)

    # backward check
    elif phase == "bwd":
        l1_lambda = 0.1
        base_lr = 0.01

        optim = torch.optim.SGD(torch_model.parameters(),
                                base_lr,
                                weight_decay=0.0,
                                momentum=0.0)

        torch_output = torch_model(
            *[torch.from_numpy(t).long() for t in inputs])
        # pytorch backward
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()  # calculate gradients
        optim.step()  # update gradients
        torch_outputs = [torch_output.detach().numpy()]

        # TF
        with tf.Graph().as_default():
            tf_model = TFBertModel(tf_config, is_training=True)
            with ops.device('cpu'):
                input_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                           dtype=tf.int32)
                position_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                              dtype=tf.int32)
                segment_ids = tf.placeholder(shape=[
                    test_config.batch_size, test_config.sequence_length
                ],
                                             dtype=tf.int32)

            # Single-IPU configuration for the TF run.
            cfg = utils.create_ipu_config()
            cfg = utils.auto_select_ipus(cfg, 1)
            utils.configure_ipu_system(cfg)
            utils.move_variable_initialization_to_cpu()

            def embedding_graph(input_ids, position_ids, segment_ids):
                # Same L1 loss and one SGD step as the torch side.
                embedding_output = tf_model.embeddings_layer(
                    input_ids, position_ids, segment_ids)
                l1_loss = l1_lambda * tf.norm(embedding_output, 1)
                optimizer = tf.train.GradientDescentOptimizer(base_lr)
                train_step = optimizer.minimize(l1_loss)
                return embedding_output, l1_loss, train_step

            with ops.device("/device:IPU:0"):
                opt = ipu_compiler.compile(
                    embedding_graph,
                    inputs=[input_ids, position_ids, segment_ids])

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                var_and_init = copy_torch_weights_to_tf(
                    torch_model, tf_model, TF_TO_TORCH, {}, sess)
                sess.run(var_and_init)
                tvars = sess.run({v.name: v
                                  for v in tf.trainable_variables()})
                print(tvars)
                tf_outputs, tf_loss = sess.run(
                    opt, {
                        input_ids: indices,
                        position_ids: positions,
                        segment_ids: segments
                    })
                # sess.run(opt, {input_ids: indices, position_ids: positions, segment_ids: segments})
                # Compare the forward output
                check_tf_torch_model(sess,
                                     torch_model,
                                     TF_TO_TORCH,
                                     margin=5e-7)
            check_tensors(torch_outputs, tf_outputs, margin=5e-7)
    else:
        raise ValueError(
            f"`phase` only can be set to [`fwd`, `bwd`] which mean farward or backward respectively."
        )
def test_embedding_fwd(custom_ops, mode, batch_size,
                       batch_serialization_factor,
                       embedding_serialization_vocab_steps):
    """Forward-pass parity test for the embedding layer across execution
    modes (default vs PHASED): PopART vs PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    config = BertConfig(
        task="SQUAD",
        vocab_length=9728,
        batch_size=batch_size,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        inference=True,
        embedding_serialization_vocab_steps=embedding_serialization_vocab_steps
    )
    popart_model = get_model(config, mode, 'embedding')

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = popart_model.builder.addInputTensor(sequence_info)
    positions = popart_model.builder.addInputTensor(sequence_info)
    segments = popart_model.builder.addInputTensor(sequence_info)
    # Random token/position/segment ids, flattened to [batch * seq_len].
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    # PHASED execution drives the model via __call__ and needs extra options.
    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialization_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(indices, positions, segments)
    else:
        user_options = {"enableStochasticRounding": True}
        with popart_model.builder.nameScope("Embedding"):
            output = popart_model.embedding(indices, positions, segments)

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    # torch embeddings take signed int ids shaped [batch, seq_len].
    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    # Expand the mode-specific mapping to the serialized weight names.
    expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map(
        TORCH_TO_ONNX[mode], config, mode)
    copy_weights_to_torch(torch_model, proto, expanded_name_map,
                          remapped_transform_map)

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs, margin=5e-7)
def test_load_from_chkpt(config_path, chkpt_path, custom_ops):
    """
    Compare the model loaded into our popart model against the modified PyTorch model:
    - Load tf weights into BERT using torch impl -> run fwd model
    - Load tf weights into BERT using popart impl -> run fwd model
    - Compare output tensors
    """
    config = load_bert_config_tf(config_path)

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })

    # Load Torch version
    torch_model = TorchModel(
        TorchBertConfig(
            config.vocab_length,
            config.hidden_size,
            num_hidden_layers=config.num_layers,
            num_attention_heads=config.attention_heads,
            intermediate_size=config.ff_size,
            hidden_act="relu",
            max_position_embeddings=config.max_positional_length,
            layer_norm_eps=config.layer_norm_eps,
            mask_tokens=config.mask_tokens,
        ))
    torch_model.eval()  # dropout off for a deterministic comparison
    torch_model = load_tf_weights_in_bert(torch_model, config, chkpt_path)

    # Load Popart model
    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)

    popart_model, proto, output = load_from_tf(chkpt_path,
                                               True,
                                               config,
                                               indices,
                                               positions,
                                               builder=builder)

    # Run the models
    popart_inputs = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32),
        positions:
        np.random.randint(
            0,
            config.sequence_length,
            (config.batch_size * config.sequence_length),
        ).astype(np.int32),
    }

    # Same data, reshaped to [batch, seq_len] for the torch model.
    torch_inputs = {
        "input_ids":
        popart_inputs[indices].reshape(config.batch_size,
                                       config.sequence_length),
        "position_ids":
        popart_inputs[positions].reshape(config.batch_size,
                                         config.sequence_length),
    }

    torch_outputs = run_fwd_model(torch_inputs, torch_model)

    popart_outputs, post_proto = run_py(
        proto,
        popart_inputs,
        output,
        ipus=math.ceil(config.num_layers / config.layers_per_ipu) + 1,
    )

    check_tensors(torch_outputs, popart_outputs)
    print("Test succeeded")
def test_embedding_projection_fwd(custom_ops):
    """Forward-pass parity test for the custom embedding + projection (tied
    LM head) path: PopART vs PyTorch.

    NOTE(review): ``torch_to_onnx`` and ``transposed_weights`` are not
    defined in this function — presumably module-level mappings; confirm
    they exist in the enclosing module.
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['gather'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    sequence_info = popart.TensorInfo(
        "INT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.int32)
    }

    # Build embedding lookup -> norm -> dropout -> CLS head -> projection.
    x = popart_model.embedding_custom(
        indices,
        config.vocab_length,
        "Embedding_Dict",
        detach=True)
    x = popart_model.norm(x)
    x = popart_model.dropout(x)
    with popart_model.device_scope(nameScope="CLS"):
        x = popart_model.lm_prediction_head(x)
        output = popart_model.projection(x)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        user_options={"enableStochasticRounding": True})

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    inputs = [data[indices].reshape(config.batch_size,
                                    config.sequence_length)]

    # ------------------- PyTorch -------------------------
    torch_model = EmbeddingProjectionModel(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transposed_weights)
    # Re-tie the decoder weights to the embedding after the copy.
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_attention_fwd(custom_ops):
    """Forward-pass parity test for the custom attention layer: PopART vs
    PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    # Per-batch mask lengths: one for masked-LM tokens, one for the sequence.
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor,
                                    [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # Q, K and V all live in the single fused "QKV" ONNX weight.
    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    # Slice the fused QKV matrix into per-projection weights (and transpose
    # to torch's [out, in] layout).
    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_embedding_fwd(custom_ops):
    """Forward-pass parity test for the embedding layer (SQUAD config,
    explicit builder): PopART vs PyTorch.

    ``custom_ops`` is a pytest fixture (unused directly; presumably ensures
    the custom op library is built/loaded — confirm).
    """
    # ------------------- PopART --------------------
    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })
    config = BertConfig(task="SQUAD",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        inference=True)
    popart_model = Bert(config, builder=builder)
    # Prevent virtualGraph attributes being added to the ops.
    popart_model.embedding_scope = popart_model.device_scope(None, None)
    popart_model.embedding_split_scope = popart_model.embedding_scope

    sequence_info = popart.TensorInfo(
        "UINT32", [config.batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    # Random token/position/segment ids, flattened to [batch * seq_len].
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        positions:
        np.random.randint(0, config.max_positional_length,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32),
        segments:
        np.random.randint(0, 2,
                          (config.batch_size * config.sequence_length)).astype(
                              np.uint32)
    }

    output = popart_model.embedding(indices, positions, segments)

    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    # torch embeddings take signed int ids shaped [batch, seq_len].
    inputs = [
        data[t].reshape(config.batch_size,
                        config.sequence_length).astype(np.int32)
        for t in [indices, positions, segments]
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertEmbeddings(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        max_position_embeddings=config.max_positional_length,
                        layer_norm_eps=config.layer_norm_eps))
    torch_model.eval()
    copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {})
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def fwd_graph(popart_model, torch_model, mapping=None, transform=None,
              replication_factor=1, replicated_tensor_sharding=False):
    """Run the PopART forward graph (optionally replicated and with
    replicated tensor sharding) and the equivalent PyTorch model on the same
    random inputs, then check that their outputs match.

    Args:
        popart_model: PopART BERT wrapper exposing ``config``/``builder``.
        torch_model: Equivalent ``torch.nn.Module`` to compare against.
        mapping: Optional override/init for the torch->onnx name mapping.
        transform: Optional override/init for the weight transform mapping.
        replication_factor: Number of data-parallel replicas.
        replicated_tensor_sharding: Passed through to ``run_py``.
    """
    # ------------------- PopART --------------------
    config = popart_model.config
    builder = popart_model.builder

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    # One batch of random ids per replica: [replicas, micro_batch * seq_len].
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (replication_factor,
                           config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32),
        positions:
        np.random.randint(0, config.sequence_length,
                          (replication_factor,
                           config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32),
        segments:
        np.random.randint(0, 2,
                          (replication_factor,
                           config.micro_batch_size *
                           config.sequence_length)).astype(np.uint32)
    }

    output = popart_model.build_graph(indices, positions, segments)
    ipus = popart_model.total_ipus

    proto = builder.getModelProto()

    outputs, _ = run_py(proto,
                        data,
                        output,
                        replication_factor=replication_factor,
                        replicated_tensor_sharding=replicated_tensor_sharding,
                        ipus=ipus)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    # Fold the replica axis into the batch axis for the torch model.
    inputs = {
        "input_ids":
        data[indices].reshape(replication_factor * config.micro_batch_size,
                              config.sequence_length).astype(np.int32),
        "position_ids":
        data[positions].reshape(replication_factor * config.micro_batch_size,
                                config.sequence_length).astype(np.int32),
        "token_type_ids":
        data[segments].reshape(replication_factor * config.micro_batch_size,
                               config.sequence_length).astype(np.int32)
    }

    torch_to_onnx = get_mapping(config, init=mapping)
    transform_weights = get_transform(config, init=transform)

    # ------------------- PyTorch -------------------------
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx,
                          transform_weights)
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_attention_fwd(mode, micro_batch_size, batch_serialisation_factor,
                       number_attention_splits, attention_bias, split_qkv):
    """Forward-pass parity test for the attention layer across execution
    modes, with optional attention bias and split-QKV weights: PopART vs
    PyTorch.
    """
    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        attention_heads=4,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length,
         config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    # Per-token masks: masked-LM mask and sequence mask.
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)
    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    # PHASED execution drives the model via __call__ and needs extra options.
    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])

    proto = popart_model.builder.getModelProto()
    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)
    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    # Split-QKV models use a different torch->onnx name mapping.
    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]

    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)