def test_attention_bwd(mode, momentum, micro_batch_size, batch_serialisation_factor,
                       number_attention_splits, attention_bias):
    """Backward-pass and weight-update parity test for the attention layer:
    PopART vs. a PyTorch BertAttention reference, over several optimiser steps."""
    l1_lambda = 0.1
    num_reps = 5
    np.random.seed(1984)
    torch.manual_seed(1984)
    split_qkv = False

    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
    else:
        user_options = {}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output], l1_lambda, reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()

    if momentum:
        optimizer = popart.SGD({
            "defaultLearningRate": (0.01, True),
            "defaultMomentum": (momentum, True)
        })
    else:
        optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=optimizer,
                                 num_reps=num_reps,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]

    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=momentum)

    if momentum:
        for group in optim.param_groups:
            for p in group['params']:
                optim.state[p]['momentum_buffer'] = p.data * 0
                optim.state[p]['exp_avg'] = p.data * 0
                optim.state[p]['exp_avg_sq'] = p.data * 0
                optim.state[p]['step'] = 0

    for _ in range(num_reps):
        torch_output = torch_model(
            *[torch.from_numpy(t).float() for t in inputs])[0]
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    check_tensors([torch_output.detach().numpy()], outputs, margin=6e-07)

    check_model(torch_model,
                post_proto,
                mapping,
                transform=get_transform(split_qkv, config.hidden_size),
                margin=2e-7)
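
# The tests above rely on repository helpers such as `get_transform`,
# `TORCH_TO_ONNX` and `get_torch_mask`. As a hedged sketch only (not the
# repository's actual implementation), `get_transform` is assumed to build a
# name-to-callable mapping like the inline `split_qkv` dictionaries used in
# the older tests further down: slicing the fused ONNX QKV weight into the
# per-projection PyTorch weights and transposing dense weights. The name
# `_example_get_transform` and the split-QKV branch are illustrative
# assumptions.
def _example_get_transform(split_qkv, hidden_size):
    if split_qkv:
        # Assumption: Q, K and V already live in separate ONNX tensors,
        # so each weight only needs transposing.
        return {
            "self.query.weight": np.transpose,
            "self.key.weight": np.transpose,
            "self.value.weight": np.transpose,
            "output.dense.weight": np.transpose,
        }
    # Fused case, mirroring the inline dictionaries below: Q, K and V are
    # packed into one [hidden_size, 3 * hidden_size] ONNX tensor.
    return {
        "self.query.weight": lambda arr: arr[:, 0:hidden_size].T,
        "self.key.weight": lambda arr: arr[:, hidden_size:hidden_size * 2].T,
        "self.value.weight": lambda arr: arr[:, hidden_size * 2:hidden_size * 3].T,
        "output.dense.weight": np.transpose,
    }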
def test_attention_fwd(mode, micro_batch_size, batch_serialisation_factor,
                       number_attention_splits, attention_bias, split_qkv):
    """Forward-pass parity test for the attention layer:
    PopART vs. a PyTorch BertAttention reference."""
    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=micro_batch_size,
                        hidden_size=768,
                        attention_heads=4,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        inference=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias,
                        num_attention_splits=number_attention_splits)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": batch_serialisation_factor,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])

    proto = popart_model.builder.getModelProto()

    outputs, post_proto = run_py(proto,
                                 data,
                                 output,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    mapping = TORCH_TO_ONNX[mode]
    if split_qkv:
        mapping = TORCH_TO_ONNX_SPLIT_QKV[mode]

    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_attention_fwd(custom_ops):
    """Forward-pass parity test for the custom-op attention layer:
    PopART vs. a PyTorch BertAttention reference."""
    # ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'],
                        inference=True)
    popart_model = Bert(config, builder=builder)

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor,
                                    [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    outputs, post_proto = run_py(proto, data, output)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size).astype(np.float32),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    # Model to test against
    torch_outputs = run_fwd_model(inputs, torch_model)

    check_tensors(torch_outputs, outputs)
def test_attention_bwd(custom_ops):
    """Backward-pass and weight-update parity test for the custom-op attention
    layer: PopART vs. a PyTorch BertAttention reference."""
    l1_lambda = 0.1

    # ------------------- PopART --------------------
    builder = popart.Builder()
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        custom_ops=['attention'])
    popart_model = Bert(config, builder=builder)

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("INT32", [config.batch_size])
    mmask_tensor = builder.addInputTensor(mask_info)
    smask_tensor = builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.int32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.int32)
    }

    output = popart_model.attention(input_tensor,
                                    [mmask_tensor, smask_tensor])
    proto = builder.getModelProto()

    l1 = popart.L1Loss(output, "l1LossVal", l1_lambda)
    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1.output(0)),
                                 loss=l1,
                                 optimizer=optimizer)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    torch_to_onnx = {
        "self.query.weight": "QKV",
        "self.key.weight": "QKV",
        "self.value.weight": "QKV",
        "output.dense.weight": "Out",
        "output.LayerNorm.weight": "Gamma",
        "output.LayerNorm.bias": "Beta"
    }

    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          torch_to_onnx,
                          transform=split_qkv)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(
        *[torch.from_numpy(t).float() for t in inputs])[0]
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs)

    check_model(torch_model, post_proto, torch_to_onnx, transform=split_qkv)
def test_attention_bwd(mode):
    """Backward-pass and weight-update parity test for the attention layer
    across execution modes: PopART vs. a PyTorch BertAttention reference."""
    l1_lambda = 0.1

    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True)
    popart_model = get_model(config, mode, 'attention')

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo("UINT32", [config.batch_size])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1,
                          (config.batch_size, )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1,
                          (config.batch_size, )).astype(np.uint32)
    }

    user_options = {}
    if mode == ExecutionMode.PHASED:
        user_options = {
            "batchSerializationFactor": 1,
            "executionPhases": popart_model.total_execution_phases
        }
        output = popart_model(input_tensor, [mmask_tensor, smask_tensor])
        with popart_model.scope_provider(popart_model.builder,
                                         popart_model.norm.scope):
            l1 = popart_model.builder.aiGraphcore.l1loss(
                [output],
                l1_lambda,
                debugPrefix="l1LossVal",
                reduction=popart.ReductionType.Sum)
    else:
        user_options = {"enableStochasticRounding": True}
        output = popart_model.attention(input_tensor,
                                        [mmask_tensor, smask_tensor])
        l1 = popart_model.builder.aiGraphcore.l1loss(
            [output], l1_lambda, reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()

    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=optimizer,
                                 user_options=user_options,
                                 execution_mode=mode)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.batch_size, config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    split_qkv = {
        "self.query.weight":
        lambda arr: arr[:, 0:config.hidden_size].T,
        "self.key.weight":
        lambda arr: arr[:, config.hidden_size:config.hidden_size * 2].T,
        "self.value.weight":
        lambda arr: arr[:, config.hidden_size * 2:config.hidden_size * 3].T,
        "output.dense.weight":
        np.transpose
    }

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model,
                          proto,
                          TORCH_TO_ONNX[mode],
                          transform=split_qkv)

    optim = torch.optim.SGD(torch_model.parameters(),
                            0.01,
                            weight_decay=0.0,
                            momentum=0.0)

    torch_output = torch_model(
        *[torch.from_numpy(t).float() for t in inputs])[0]
    torch_loss = l1_lambda * torch.norm(torch_output, 1)
    torch_loss.backward()
    optim.step()

    check_tensors([torch_output.detach().numpy()], outputs)

    check_model(torch_model,
                post_proto,
                TORCH_TO_ONNX[mode],
                transform=split_qkv)
def test_attention_bwd(attention_bias, split_qkv):
    """Backward-pass and weight-update parity test for the attention layer with
    pipelining enabled: PopART vs. a PyTorch BertAttention reference."""
    l1_lambda = 0.1
    num_reps = 5
    np.random.seed(1984)
    torch.manual_seed(1984)

    # ------------------- PopART --------------------
    config = BertConfig(task="PRETRAINING",
                        vocab_length=9728,
                        micro_batch_size=1,
                        hidden_size=768,
                        sequence_length=128,
                        activation_type='relu',
                        popart_dtype="FLOAT",
                        no_dropout=True,
                        no_attn_dropout=True,
                        split_qkv=split_qkv,
                        attention_bias=attention_bias)
    popart_model = Bert(config, pipeline=True)

    input_info = popart.TensorInfo(
        config.popart_dtype,
        [config.micro_batch_size * config.sequence_length, config.hidden_size])
    input_tensor = popart_model.builder.addInputTensor(input_info)
    mask_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size, config.sequence_length])
    mmask_tensor = popart_model.builder.addInputTensor(mask_info)
    smask_tensor = popart_model.builder.addInputTensor(mask_info)

    data = {
        input_tensor:
        np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype),
        mmask_tensor:
        np.random.randint(0, config.mask_tokens + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32),
        smask_tensor:
        np.random.randint(config.mask_tokens, config.sequence_length + 1, (
            config.micro_batch_size,
            config.sequence_length,
        )).astype(np.uint32)
    }

    user_options = {}
    output = popart_model.attention(input_tensor,
                                    [mmask_tensor, smask_tensor])
    l1 = popart_model.builder.aiGraphcore.l1loss(
        [output], l1_lambda, reduction=popart.ReductionType.Sum)

    proto = popart_model.builder.getModelProto()

    optimizer = popart.ConstSGD(0.01)

    outputs, post_proto = run_py(proto,
                                 data, (output, l1),
                                 loss=l1,
                                 optimizer=optimizer,
                                 num_reps=num_reps,
                                 user_options=user_options,
                                 pipeline=popart_model.pipeline)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = [
        data[input_tensor].reshape(config.micro_batch_size,
                                   config.sequence_length,
                                   config.hidden_size),
        get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]])
    ]

    # ------------------- PyTorch -------------------------
    torch_model = BertAttention(
        TorchBertConfig(config.vocab_length,
                        config.hidden_size,
                        config.num_layers,
                        config.attention_heads,
                        attention_bias=config.attention_bias,
                        layer_norm_eps=config.layer_norm_eps))
    # Turn off dropout
    torch_model.eval()

    mapping = TORCH_TO_ONNX_SPLIT_QKV if split_qkv else TORCH_TO_ONNX

    copy_weights_to_torch(torch_model,
                          proto,
                          mapping,
                          transform=get_transform(split_qkv,
                                                  config.hidden_size))

    optim = torch.optim.SGD(torch_model.parameters(), 0.01)

    for _ in range(num_reps):
        torch_output = torch_model(
            *[torch.from_numpy(t).float() for t in inputs])[0]
        torch_loss = l1_lambda * torch.norm(torch_output, 1)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    check_tensors([torch_output.detach().numpy()], outputs, margin=6e-07)

    check_model(torch_model,
                post_proto,
                mapping,
                transform=get_transform(split_qkv, config.hidden_size),
                margin=2e-7)