def test_embedding_projection_fwd(custom_ops):
    """Forward-pass check of the embedding -> LM head -> projection path.

    The graph is built with PopART (``inference=True``) and executed via
    ``run_py``.  The resulting ONNX weights are then copied into a PyTorch
    ``EmbeddingProjectionModel`` and both sets of outputs are handed to
    ``check_tensors``.
    """
    # ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(
        vocab_length=9728,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=['gather'],
        inference=True,
    )
    popart_model = Bert(config, builder=builder)

    # The token ids go in as a single flat INT32 vector; the torch side
    # reshapes the same data to (batch_size, sequence_length) further down.
    num_tokens = config.batch_size * config.sequence_length
    indices = builder.addInputTensor(popart.TensorInfo("INT32", [num_tokens]))
    token_ids = np.random.randint(0, config.vocab_length, num_tokens)
    data = {indices: token_ids.astype(np.int32)}

    hidden = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    hidden = popart_model.norm(hidden)
    hidden = popart_model.dropout(hidden)
    # NOTE(review): only the LM prediction head is placed under the "CLS"
    # scope here; confirm the projection is meant to sit outside it.
    with popart_model.device_scope(nameScope="CLS"):
        hidden = popart_model.lm_prediction_head(hidden)
    output = popart_model.projection(hidden)

    proto = builder.getModelProto()
    outputs, _ = run_py(
        proto,
        data,
        output,
        user_options={"enableStochasticRounding": True},
    )

    # ----------------- PopART -> PyTorch ----------------
    onnx_model = onnx.load_model_from_string(proto)
    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    # ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
    )
    torch_model = EmbeddingProjectionModel(torch_config)
    torch_model.eval()

    copy_weights_to_torch(
        torch_model, onnx_model, torch_to_onnx, transposed_weights)
    torch_model.tie_weights()

    torch_outputs = run_fwd_model(inputs, torch_model)
    check_tensors(torch_outputs, outputs)
def test_embedding_projection_bwd(custom_ops):
    """Training-step check of the embedding -> LM head -> projection path.

    The PopART graph is run through ``run_py`` with an L1 loss and a
    constant-rate SGD optimizer.  The PyTorch ``EmbeddingProjectionModel``
    is loaded with the same initial weights and takes one SGD step on the
    equivalent scaled L1 loss.  The forward outputs are then passed to
    ``check_tensors`` and the torch model plus the proto returned by
    ``run_py`` are passed to ``check_model``.
    """
    l1_lambda = 0.1
    learning_rate = 0.01  # shared by the PopART and PyTorch optimizers

    # ------------------- PopART --------------------
    opsets = {"ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1}
    builder = popart.Builder(opsets=opsets)
    config = BertConfig(
        vocab_length=9728,
        batch_size=1,
        hidden_size=768,
        sequence_length=128,
        activation_type='relu',
        popart_dtype="FLOAT",
        no_dropout=True,
        custom_ops=['gather'],
    )
    popart_model = Bert(config, builder=builder)

    # The token ids go in as a single flat INT32 vector; the torch side
    # reshapes the same data to (batch_size, sequence_length) further down.
    num_tokens = config.batch_size * config.sequence_length
    indices = builder.addInputTensor(popart.TensorInfo("INT32", [num_tokens]))
    token_ids = np.random.randint(0, config.vocab_length, num_tokens)
    data = {indices: token_ids.astype(np.int32)}

    hidden = popart_model.embedding_custom(
        indices, config.vocab_length, "Embedding_Dict", detach=True)
    hidden = popart_model.norm(hidden)
    hidden = popart_model.dropout(hidden)
    # NOTE(review): only the LM prediction head is placed under the "CLS"
    # scope here; confirm the projection is meant to sit outside it.
    with popart_model.device_scope(nameScope="CLS"):
        hidden = popart_model.lm_prediction_head(hidden)
    output = popart_model.projection(hidden)

    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(learning_rate)

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=l1,
        optimizer=optimizer,
        user_options={"enableStochasticRounding": True},
    )

    # ----------------- PopART -> PyTorch ----------------
    onnx_model = onnx.load_model_from_string(proto)
    inputs = [data[indices].reshape(config.batch_size, config.sequence_length)]

    # ------------------- PyTorch -------------------------
    torch_config = TorchBertConfig(
        config.vocab_length,
        config.hidden_size,
        max_position_embeddings=config.max_positional_length,
        layer_norm_eps=config.layer_norm_eps,
    )
    torch_model = EmbeddingProjectionModel(torch_config)
    # Put the module in eval mode so dropout is switched off.
    torch_model.eval()

    copy_weights_to_torch(
        torch_model, onnx_model, torch_to_onnx, transform=transposed_weights)

    sgd = torch.optim.SGD(
        torch_model.parameters(),
        learning_rate,
        weight_decay=0.0,
        momentum=0.0,
    )

    torch_inputs = [torch.from_numpy(arr).long() for arr in inputs]
    torch_output = torch_model(*torch_inputs)
    # Torch counterpart of the PopART L1Loss above: lambda * ||output||_1.
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    sgd.step()

    check_tensors([torch_output.detach().numpy()], outputs, margin=1e-5)
    check_model(
        torch_model, post_proto, torch_to_onnx, transform=transposed_weights)