def popart_result_and_model(popart_config, is_bwd=False): builder = popart.Builder() popart_model = Bert(popart_config, builder=builder) input_info = popart.TensorInfo(popart_config.popart_dtype, [ popart_config.batch_size * popart_config.sequence_length, popart_config.hidden_size ]) input_tensor = builder.addInputTensor(input_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(popart_config.dtype) } output = popart_model.feed_forward(input_tensor) proto = builder.getModelProto() if is_bwd: l1_lambda = 0.1 l1 = popart.L1Loss(output, "l1LossVal", l1_lambda) optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py(proto, data, (output, l1.output(0)), loss=l1, optimizer=optimizer) else: outputs, post_proto = run_py(proto, data, output) return data[input_tensor], outputs, proto, post_proto
def session(train=False, skip_execution=False, include_patterns=True, splits=1, outline=False, optim="Sgd"): proto, data, x, loss = model(splits=splits) patterns = popart.Patterns() patterns.enablePattern("TiedGatherPattern", include_patterns) patterns.enablePattern("TiedGatherAccumulatePattern", include_patterns) user_options = { "enableOutlining": outline, "enableGradientAccumulation": True, "accumulationFactor": 2, "accumulationAndReplicationReductionType": popart.ReductionType.Mean, "meanAccumulationAndReplicationReductionStrategy": popart.MeanReductionStrategy.Running } if optim == "Lamb": optimizer = popart.Adam({ "defaultLearningRate": (0.1, False), "defaultWeightDecay": (0.1, True), "defaultBeta1": (0.1, True), "defaultBeta2": (0.1, True), "lossScaling": (20, True), }, mode=popart.AdamMode.LambNoBias) # NoBias to increase the error of incorrect gradients user_options["optimizerStateTensorLocationSettings"] = popart.TensorLocationSettings( popart.TensorLocation( popart.TensorStorage.OffChip, popart.ReplicatedTensorSharding.On), 0, 0) user_options["enableReplicatedGraphs"] = True user_options["replicatedGraphCount"] = 2 ipus = 2 else: optimizer = popart.SGD({ "defaultLearningRate": (0.1, True), "defaultMomentum": (0.9, True), "defaultDampening": (0, True), # 0 dampening to increase the error of incorrect gradients "lossScaling": (20, True)}) ipus = 1 if train: return run_py( proto, data=data, outputs=x, loss=loss, optimizer=optimizer, patterns=patterns, user_options=user_options, skip_execution=skip_execution) else: return run_py( proto, data=data, outputs=x, patterns=patterns, user_options={ "enableOutlining": outline, "constantWeights": False }, skip_execution=skip_execution)
def popart_result_and_model(config, weight_transposed, is_bwd=False): """Run popart model based on config. Args: config (BertConfig): Popart config. weight_transposed: Construct embedding dict transposed. is_bwd (bool, optional): Construct training graph if True, else inference graph. Defaults to False. Returns: Tuple: Gathered numpy data, outputs from model, proto, post_proto """ user_options = {} popart_model = Bert(config) builder = popart_model.builder indices_len = config.micro_batch_size * config.sequence_length sequence_info = popart.TensorInfo("UINT32", [indices_len]) indices = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (indices_len)).astype(np.uint32) } output = popart_model.word_embedding_serialized(indices, num_splits) if is_bwd: l1_loss = popart_model.builder.aiGraphcore.l1loss( [output], 0.1, debugContext="l1LossVal", reduction=popart.ReductionType.Sum) proto = builder.getModelProto() optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py(proto, data, (output, l1_loss), loss=l1_loss, optimizer=optimizer, user_options=user_options) else: proto = builder.getModelProto() outputs, post_proto = run_py(proto, data, output, user_options=user_options) return [data[indices]], outputs, proto, post_proto
def fwd_graph(popart_model, torch_model, mapping=None, transform=None): # ------------------- PopART -------------------- config = popart_model.config builder = popart_model.builder sequence_info = popart.TensorInfo( "INT32", [config.batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) segments = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.int32), positions: np.random.randint(0, config.sequence_length, (config.batch_size * config.sequence_length)).astype( np.int32), segments: np.random.randint(0, 2, (config.batch_size * config.sequence_length)).astype( np.int32) } output = popart_model.build_graph(indices, positions, segments) proto = builder.getModelProto() outputs, post_proto = run_py( proto, data, output, ipus=math.ceil(config.num_layers / config.layers_per_ipu) + popart_model.layer_offset) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = { "input_ids": data[indices].reshape(config.batch_size, config.sequence_length), "position_ids": data[positions].reshape(config.batch_size, config.sequence_length), "token_type_ids": data[segments].reshape(config.batch_size, config.sequence_length) } torch_to_onnx = get_mapping(config, init=mapping) transform_weights = get_transform(config, init=transform) # ------------------- PyTorch ------------------------- # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights) torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs)
def session(splits=1): proto, data, x, loss = model(splits) user_options = { "enableOutlining": False, "enableGradientAccumulation": True, "accumulationFactor": 2, "optimizerStateTensorLocationSettings": popart.TensorLocationSettings( popart.TensorStorage.OffChip, 0) } optimizer = popart.Adam({ "defaultLearningRate": (0.1, True), "defaultBeta1": (0.1, True), "defaultBeta2": (0.1, True) }, mode=popart.AdamMode.LambNoBias) # NoBias to increase the error of incorrect gradients return run_py( proto, data=data, outputs=x, loss=loss, optimizer=optimizer, patterns=popart.Patterns(), user_options=user_options, skip_execution=False)
def test_outline_dropout_pattern_one(custom_ops): ''' Tests that the OutlineDropoutPattern successfully outlines all 3 dropouts (fwd, bwd) into a single subgraph Expected IR Graph (excluding adds etc) fwd... x = add(data0, weight0) 0_seed = seedModify(seed, 0) x = call_0(x, 0_seed) 1_seed = seedModify(seed, 1) x = call_0(x, 1_seed) 2_seed = seedModify(seed, 2) x = call_0(x, 2_seed) bwd... x = call_0(x, 0_seed) x = call_0(x, 1_seed) x = call_0(x, 2_seed) where call_0(x, seed) = dropout(x, seed) ''' input_data = np.random.rand(2, 2).astype(np.float32) builder = popart.Builder() d0 = builder.addInputTensor(popart.TensorInfo('FLOAT', input_data.shape), 'data0') w0 = builder.addInitializedInputTensor(input_data, 'weight0') x = builder.aiOnnx.add([d0, w0]) x = builder.aiOnnx.dropout([x], 1)[0] x = builder.aiOnnx.dropout([x], 1)[0] x = builder.aiOnnx.dropout([x], 1)[0] session = run_py(builder.getModelProto(), data={d0: input_data}, outputs=x, loss=popart.L1Loss(x, 'loss', 0.1), optimizer=popart.ConstSGD(0.1), patterns=popart.Patterns( ["OutlineDropoutPattern", "PostNRepl"]), user_options={"outlineThreshold": -1}, skip_execution=True) ir = json.loads(session._serializeIr(popart.IrSerializationFormat.JSON)) # There should only be a main graph and 1 subgraph containing dropout assert len(ir.keys()) == 2 ops = [o["type"] for o in ir["_subgraph(0)"]] assert "Dropout" in ops ops = [o["type"] for o in ir["maingraph"]] # Should only be 1 seed modify per dropout assert len(list(filter(lambda op: op == "SeedModify", ops))) == 6 # The bwd and fwd should be outlined together assert len(list(filter(lambda op: op == "Call", ops))) == 6
def popart_result_and_model(popart_config, weight_decay=0, lr=0, l1_lambda=0): builder = popart.Builder() popart_model = Bert(popart_config, builder=builder) input_info = popart.TensorInfo(popart_config.popart_dtype, [ popart_config.batch_size * popart_config.sequence_length, popart_config.hidden_size ]) input_tensor = builder.addInputTensor(input_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(popart_config.dtype) } output = popart_model.feed_forward(input_tensor) proto = builder.getModelProto() l1 = popart.L1Loss(output, "l1LossVal", l1_lambda) iteration = MockIteration() args = MockArgs(lr, weight_decay) optimizer_factory = BaseOptimizerFactory(args, iteration, popart_model.tensors) optimizer = optimizer_factory.create() outputs, post_proto = run_py(proto, data, (output, l1.output(0)), loss=l1, optimizer=optimizer) return data[input_tensor], outputs, proto, post_proto
def _test_cmd(name: str, tmp_path, argv: List[str], raises: bool): assert_smart_equals_ref( f'test_main.{name}', run_py(tmp_path=tmp_path, argv=['-m', 'traceback_with_variables.main'] + argv, raises=raises).replace('[script-arg [script-arg ...]]', '[script-arg ...]') # python 3.9+ )
def popart_result_and_model(popart_config, is_bwd=False, momentum=0.0): popart_model = Bert(popart_config) input_info = popart.TensorInfo(popart_config.popart_dtype, [ popart_config.micro_batch_size * popart_config.sequence_length, popart_config.hidden_size ]) input_tensor = popart_model.builder.addInputTensor(input_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(popart_config.dtype) } output = popart_model.feed_forward(input_tensor) if is_bwd: l1 = popart_model.builder.aiGraphcore.l1loss( [output], 0.1, debugContext="l1LossVal", reduction=popart.ReductionType.Sum) proto = popart_model.builder.getModelProto() if momentum > 0.0: optimizer = popart.SGD({ "defaultLearningRate": (lr, False), "defaultMomentum": (momentum, False), "defaultWeightDecay": (0.0, False) }) else: optimizer = popart.ConstSGD(lr) outputs, post_proto = run_py(proto, data, (output, l1), loss=l1, optimizer=optimizer, num_reps=num_reps_bwd) else: proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output) return data[input_tensor], outputs, proto, post_proto
def run_models(config, proto, indices, positions, segments, output, popart_model, torch_model): onnx_proto = onnx.load_model_from_string(proto) check_model(torch_model, onnx_proto, get_mapping(config), get_transform(config)) # Run the models popart_inputs = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint( 0, config.sequence_length, (config.batch_size * config.sequence_length), ).astype(np.uint32), segments: np.random.randint( 0, 2, (config.batch_size * config.sequence_length), ).astype(np.uint32), } popart_outputs, post_proto = run_py( proto, popart_inputs, output, ipus=popart_model.total_ipus, ) torch_inputs = { "input_ids": popart_inputs[indices].reshape(config.batch_size, config.sequence_length), "position_ids": popart_inputs[positions].reshape(config.batch_size, config.sequence_length), "token_type_ids": popart_inputs[segments].reshape(config.batch_size, config.sequence_length), } torch_model.eval() torch_outputs = run_fwd_model(torch_inputs, torch_model) check_model(torch_model, post_proto, get_mapping(config), get_transform(config)) check_tensors(torch_outputs, popart_outputs) print("Test succeeded")
def _test_cmd(name: str, tmp_path, argv: List[str], raises: bool): assert_smart_equals_ref( f'test_main.{name}', re.sub( r']\s+', ']\n', re.sub( r'\[([^-][^\s]+) \[[^\s]+ ...]]', r'[\1 ...]', run_py( # for python3.9 tmp_path=tmp_path, argv=['-m', 'traceback_with_variables.main'] + argv, raises=raises))))
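# Illustrative only (not part of the original test): the inner re.sub above collapses
# argparse usage text such as "[script-arg [script-arg ...]]" to "[script-arg ...]",
# so Python 3.9+ and older usage strings compare equal, and the outer re.sub then
# puts each "]"-terminated chunk on its own line. For example:
#
#   re.sub(r'\[([^-][^\s]+) \[[^\s]+ ...]]', r'[\1 ...]',
#          'usage: main [-h] [script-arg [script-arg ...]]')
#   # -> 'usage: main [-h] [script-arg ...]'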
def session(train=False, skip_execution=False, include_patterns=True, splits=1, outline=False): proto, data, x, loss = model(splits=splits) # Required patterns = [ "MatMulOp", "MatMulLhsGradOp", "MatMulRhsGradOp", "OpToIdentity", "PreUniRepl" ] if include_patterns: patterns += ["TiedGatherPattern", "TiedGatherGradPattern"] if train: return run_py( proto, data=data, outputs=x, loss=loss, optimizer=popart.SGD({ "defaultLearningRate": (0.1, True), "defaultMomentum": (0.9, True), "defaultDampening": (0, True) }), # 0 dampening to increase the error of incorrect gradients patterns=popart.Patterns(patterns), user_options={"enableOutlining": outline}, skip_execution=skip_execution) else: return run_py(proto, data=data, outputs=x, patterns=popart.Patterns(patterns), user_options={ "enableOutlining": outline, "constantWeights": False }, skip_execution=skip_execution)
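# A minimal usage sketch (an assumption, not from the original file): the helper above
# is typically driven by running the session with and without the TiedGather patterns
# and checking that the outputs still agree. `check_tensors` is the comparison helper
# used elsewhere in these tests; the test name below is illustrative.
def test_tied_gather_pattern_sketch(custom_ops):
    outputs_with, proto_with = session(train=True, include_patterns=True)
    outputs_without, proto_without = session(train=True, include_patterns=False)
    check_tensors(outputs_with, outputs_without)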
def popart_result_and_model(popart_config, weight_decay=0.0, lr=0.0, l1_lambda=0.0): popart_model = Bert(popart_config) builder = popart_model.builder input_info = popart.TensorInfo(popart_config.popart_dtype, [ popart_config.micro_batch_size * popart_config.sequence_length, popart_config.hidden_size ]) input_tensor = builder.addInputTensor(input_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(popart_config.dtype) } output = popart_model.feed_forward(input_tensor) l1 = builder.aiGraphcore.l1loss([output], l1_lambda, debugContext="l1LossVal", reduction=popart.ReductionType.Sum) proto = builder.getModelProto() iteration = MockIteration() args = MockArgs("SGD", lr, weight_decay) optimizer_factory = BaseOptimizerFactory(args, iteration, popart_model.tensors) optimizer = optimizer_factory.create() outputs, post_proto = run_py(proto, data, (output, l1), loss=l1, optimizer=optimizer) return data[input_tensor], outputs, proto, post_proto
def session(skip_execution=False, include_patterns=True, momentum=False): proto, data, x = model() # Required patterns = [ "MatMulOp", "MatMulLhsGradOp", "MatMulRhsGradOp", "OpToIdentity", "PreUniRepl", "PostNRepl", "InPlace" ] if include_patterns: patterns += ["InplaceWorkaroundPattern"] optimizer = popart.ConstSGD(0.1) if momentum: optimizer = popart.SGD({ "defaultLearningRate": (0.1, True), "defaultMomentum": (0.9, True) }) return run_py(proto, data=data, outputs=x, loss=popart.L1Loss(x, 'loss', 0.1), optimizer=optimizer, patterns=popart.Patterns(patterns), user_options={"enableOutlining": False}, skip_execution=skip_execution)
def bwd_graph(popart_model, torch_model, popart_loss_fn, torch_loss_fn, mapping=None, transform=None): np.random.seed(1984) random.seed(1984) torch.manual_seed(1984) # ------------------- PopART -------------------- config = popart_model.config builder = popart_model.builder sequence_info = popart.TensorInfo( "UINT32", [config.batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) segments = builder.addInputTensor(sequence_info) data = { indices: np.random.randint( 0, config.vocab_length, (config.batch_size * config.sequence_length)).astype(np.uint32), positions: np.random.randint( 0, config.sequence_length, (config.batch_size * config.sequence_length)).astype(np.uint32), segments: np.random.randint( 0, 2, (config.batch_size * config.sequence_length)).astype(np.uint32) } output = popart_model.build_graph(indices, positions, segments) proto = builder.getModelProto() losses = popart_loss_fn(output) optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py( proto, data, output, loss=losses, optimizer=optimizer, ipus=math.ceil(config.num_layers / config.layers_per_ipu) + popart_model.layer_offset) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = { "input_ids": data[indices].reshape(config.batch_size, config.sequence_length).astype(np.int32), "position_ids": data[positions].reshape(config.batch_size, config.sequence_length).astype(np.int32), "token_type_ids": data[segments].reshape(config.batch_size, config.sequence_length).astype(np.int32) } torch_to_onnx = get_mapping(config, init=mapping) transform_weights = get_transform(config, init=transform) # ------------------- PyTorch ------------------------- # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights) optim = torch.optim.SGD(torch_model.parameters(), 0.01, weight_decay=0.0, momentum=0.0) torch_outputs = torch_model( **{k: torch.from_numpy(t).long() for k, t in inputs.items()}) torch_loss = torch_loss_fn(torch_outputs) torch_loss.backward() optim.step() check_tensors([output.detach().numpy() for output in torch_outputs], outputs) check_model(torch_model, post_proto, torch_to_onnx, transform_weights, margin=6e-7)
def popart_result_and_model(config, mode, weight_transposed, is_bwd=False): """Run popart model based on config. Args: config (BertConfig): Popart config. weight_transposed: Construct embedding dict transposed. is_bwd (bool, optional): Construct training graph if True, else inference graph. Defaults to False. Returns: Tuple: Gathered numpy data, outputs from model, proto, post_proto """ scope_provider = ScopeProvider() user_options = {} if mode == ExecutionMode.PHASED: builder = popart.Builder() indices_len = config.batch_size * config.sequence_length sequence_info = popart.TensorInfo("UINT32", [indices_len]) indices = builder.addInputTensor(sequence_info) data = {indices: np.random.randint(0, config.vocab_length, (indices_len)).astype(np.uint32)} popart_model = EmbeddingSerialised(scope_provider.get_scope('Token'), input_dim=config.vocab_length, output_dim=config.hidden_size, num_splits=config.embedding_serialization_vocab_steps, custom=True, dtype=config.dtype, detach=not config.update_embedding_dict, weight_transposed=weight_transposed, builder=builder, scope_provider=scope_provider) user_options = { "batchSerializationFactor": 1, "executionPhases": popart_model.total_execution_phases } output = popart_model(indices) else: popart_model = get_model(config, mode, block="embedding", initializers={}) builder = popart_model.builder indices_len = config.batch_size * config.sequence_length sequence_info = popart.TensorInfo("UINT32", [indices_len]) indices = builder.addInputTensor(sequence_info) data = {indices: np.random.randint(0, config.vocab_length, (indices_len)).astype(np.uint32)} output = popart_model.word_embedding_serialized(indices, num_splits) if is_bwd: l1_lambda = 0.1 if mode == ExecutionMode.PHASED: loss_scope = scope_provider.get_scope('Loss', 'prev') with popart_model.scope_provider(popart_model.builder, loss_scope): l1_loss = popart_model.builder.aiGraphcore.l1loss([output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) else: l1_loss = popart_model.builder.aiGraphcore.l1loss([output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) proto = builder.getModelProto() optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py(proto, data, (output, l1_loss), loss=l1_loss, optimizer=optimizer, user_options=user_options, execution_mode=mode) else: proto = builder.getModelProto() outputs, post_proto = run_py(proto, data, output, user_options=user_options, execution_mode=mode) return [data[indices]], outputs, proto, post_proto
def test_embedding_fwd(custom_ops): # ------------------- PopART -------------------- config = BertConfig(task="SQUAD", vocab_length=9728, micro_batch_size=1, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, inference=True) popart_model = Bert(config) sequence_info = popart.TensorInfo( "UINT32", [config.micro_batch_size * config.sequence_length]) indices = popart_model.builder.addInputTensor(sequence_info) positions = popart_model.builder.addInputTensor(sequence_info) segments = popart_model.builder.addInputTensor(sequence_info) data = { indices: np.random.randint( 0, config.vocab_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint( 0, config.max_positional_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint( 0, 2, (config.micro_batch_size * config.sequence_length)).astype( np.uint32) } user_options = {"enableStochasticRounding": True} with popart_model.builder.nameScope("Embedding"): output = popart_model.embedding(indices, positions, segments) proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output, user_options=user_options) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.micro_batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps)) torch_model.eval() copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {}) torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs, margin=5e-7)
def test_embedding_bwd(custom_ops): # ------------------- PopART -------------------- config = BertConfig(task="SQUAD", vocab_length=9728, micro_batch_size=1, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, update_embedding_dict=True) popart_model = Bert(config) # Prevent virtualGraph attributes being added to the ops sequence_info = popart.TensorInfo( "UINT32", [config.micro_batch_size * config.sequence_length]) indices = popart_model.builder.addInputTensor(sequence_info) positions = popart_model.builder.addInputTensor(sequence_info) segments = popart_model.builder.addInputTensor(sequence_info) data = { indices: np.random.randint( 0, config.vocab_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint( 0, config.max_positional_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint( 0, 2, (config.micro_batch_size * config.sequence_length)).astype( np.uint32) } optimizer = popart.ConstSGD(0.01) l1_lambda = 0.1 with popart_model.builder.nameScope("Embedding"): output = popart_model.embedding(indices, positions, segments) l1 = popart_model.builder.aiGraphcore.l1loss( [output], l1_lambda, debugContext="l1LossVal", reduction=popart.ReductionType.Sum) num_reps = 5 proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output, ipus=1, loss=l1, num_reps=num_reps, optimizer=optimizer) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.micro_batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps, update_embedding_dict=config.update_embedding_dict)) # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, {}) optim = torch.optim.SGD(torch_model.parameters(), 0.01) for _ in range(num_reps): torch_output = torch_model( *[torch.from_numpy(t).long() for t in inputs]) torch_loss = l1_lambda * torch.norm(torch_output, 1) torch_loss.backward() optim.step() optim.zero_grad() torch_outputs = [torch_output.detach().numpy()] check_tensors(torch_outputs, outputs, margin=7e-6) check_model(torch_model, post_proto, TORCH_TO_ONNX, {}, margin=7e-06)
def test_load_from_chkpt(config_path, chkpt_path, custom_ops): """ Compare the model loaded into our popart model against the modified PyTorch model: - Load tf weights into BERT using torch impl -> run fwd model - Load tf weights into BERT using popart impl -> run fwd model - Compare output tensors """ config = load_bert_config_tf(config_path) builder = popart.Builder(opsets={ "ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1 }) # Load Torch version torch_model = TorchModel( TorchBertConfig( config.vocab_length, config.hidden_size, num_hidden_layers=config.num_layers, num_attention_heads=config.attention_heads, intermediate_size=config.ff_size, hidden_act="relu", max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps, mask_tokens=config.mask_tokens, )) torch_model.eval() torch_model = load_tf_weights_in_bert(torch_model, config, chkpt_path) # Load Popart model sequence_info = popart.TensorInfo( "INT32", [config.batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) popart_model, proto, output = load_from_tf(chkpt_path, True, config, indices, positions, builder=builder) # Run the models popart_inputs = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.int32), positions: np.random.randint( 0, config.sequence_length, (config.batch_size * config.sequence_length), ).astype(np.int32), } torch_inputs = { "input_ids": popart_inputs[indices].reshape(config.batch_size, config.sequence_length), "position_ids": popart_inputs[positions].reshape(config.batch_size, config.sequence_length), } torch_outputs = run_fwd_model(torch_inputs, torch_model) popart_outputs, post_proto = run_py( proto, popart_inputs, output, ipus=math.ceil(config.num_layers / config.layers_per_ipu) + 1, ) check_tensors(torch_outputs, popart_outputs) print("Test succeeded")
def test_embedding_fwd(custom_ops, mode, batch_size, batch_serialization_factor, embedding_serialization_vocab_steps): # ------------------- PopART -------------------- config = BertConfig( task="SQUAD", vocab_length=9728, batch_size=batch_size, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, inference=True, embedding_serialization_vocab_steps=embedding_serialization_vocab_steps ) popart_model = get_model(config, mode, 'embedding') sequence_info = popart.TensorInfo( "UINT32", [config.batch_size * config.sequence_length]) indices = popart_model.builder.addInputTensor(sequence_info) positions = popart_model.builder.addInputTensor(sequence_info) segments = popart_model.builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint(0, config.max_positional_length, (config.batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint(0, 2, (config.batch_size * config.sequence_length)).astype( np.uint32) } user_options = {} if mode == ExecutionMode.PHASED: user_options = { "batchSerializationFactor": batch_serialization_factor, "executionPhases": popart_model.total_execution_phases } output = popart_model(indices, positions, segments) else: user_options = {"enableStochasticRounding": True} with popart_model.builder.nameScope("Embedding"): output = popart_model.embedding(indices, positions, segments) proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output, user_options=user_options, execution_mode=mode) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps)) torch_model.eval() expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map( TORCH_TO_ONNX[mode], config, mode) copy_weights_to_torch(torch_model, proto, expanded_name_map, remapped_transform_map) torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs, margin=5e-7)
def test_outline_dropout_pattern_many(custom_ops):
    '''
    Tests that the OutlineDropoutPattern successfully outlines all 3 dropouts (fwd, bwd)
    into 3 different subgraphs.

    Expected IR Graph (excluding adds etc)

    fwd...
    x = add(data0, weight0)
    0_seed = seedModify(seed, 0)
    x = call_0(x, 0_seed)
    1_seed = seedModify(seed, 1)
    x = call_1(x, 1_seed)
    2_seed = seedModify(seed, 2)
    x = call_2(x, 2_seed)

    bwd...
    x = call_2(x, 2_seed)
    x = call_1(x, 1_seed)
    x = call_0(x, 0_seed)

    where call_n(x, seed) = dropout(x, seed)
    '''
    input_data = np.random.rand(2, 2).astype(np.float32)

    builder = popart.Builder(opsets={
        "ai.onnx": 9,
        "ai.onnx.ml": 1,
        "ai.graphcore": 1
    })

    d0 = builder.addInputTensor(popart.TensorInfo('FLOAT', input_data.shape), 'data0')
    w0 = builder.addInitializedInputTensor(input_data, 'weight0')

    x = builder.aiOnnx.add([d0, w0])
    x = builder.aiOnnx.dropout([x], 1)[0]
    # Different subgraph as it has a different ratio
    x = builder.aiOnnx.dropout([x], 1, ratio=0.8)[0]
    # Different subgraph as it has a different input shape
    x = builder.aiOnnx.slice([x], axes=[1], starts=[0], ends=[1])
    x = builder.aiOnnx.dropout([x], 1)[0]

    loss = builder.aiGraphcore.l1loss([x], 0.1, debugPrefix='loss')

    patterns = popart.Patterns(popart.PatternsLevel.Minimal)
    patterns.enablePattern("OutlineDropoutPattern", True)
    patterns.enablePattern("PostNRepl", True)

    session = run_py(builder.getModelProto(),
                     data={d0: input_data},
                     outputs=x,
                     loss=loss,
                     optimizer=popart.ConstSGD(0.1),
                     patterns=patterns,
                     user_options={"outlineThreshold": -np.inf},
                     skip_execution=True)

    ir = json.loads(session._serializeIr(popart.IrSerializationFormat.JSON))

    # There should only be a main graph and 3 subgraphs, each containing a dropout
    assert len(ir.keys()) == 4

    ops = [o["type"] for i in range(3) for o in ir[f"_subgraph({i})"]]
    assert "Dropout" in ops

    ops = [o["type"] for o in ir["maingraph"]]

    # Should only be 1 seed modify per dropout
    assert len(list(filter(lambda op: op == "SeedModify", ops))) == 6
    # The bwd and fwd should be outlined together
    assert len(list(filter(lambda op: op == "Call", ops))) == 6
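# The two OutlineDropoutPattern tests above repeat the same "parse the serialized IR
# JSON and count ops of a given type" idiom. A small helper along these lines (a
# sketch, not part of the original tests) captures it; `ir_json` is the dict produced
# by json.loads(session._serializeIr(popart.IrSerializationFormat.JSON)).
def count_ops(ir_json, graph_name, op_type):
    return sum(1 for op in ir_json[graph_name] if op["type"] == op_type)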
def test_embedding_bwd(custom_ops): l1_lambda = 0.1 # ------------------- PopART -------------------- builder = popart.Builder(opsets={ "ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1 }) config = BertConfig(vocab_length=9728, batch_size=1, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, custom_ops=['gather']) popart_model = Bert(config, builder=builder) # Prevent virtualGraph attributes being added to the ops. popart_model.embedding_scope = popart_model.device_scope(None, None) popart_model.embedding_split_scope = popart_model.embedding_scope sequence_info = popart.TensorInfo( "UINT32", [config.batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) segments = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint(0, config.max_positional_length, (config.batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint(0, 2, (config.batch_size * config.sequence_length)).astype( np.uint32) } output = popart_model.embedding(indices, positions, segments) proto = builder.getModelProto() l1 = popart.L1Loss(output, "l1LossVal", l1_lambda) optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py( proto, data, output, loss=l1, optimizer=optimizer, user_options={"enableStochasticRounding": True}) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] torch_to_onnx = { "word_embeddings.weight": "Embedding_Dict", "position_embeddings.weight": "Positional_Dict", "token_type_embeddings.weight": "Segment_Dict", "LayerNorm.weight": "Gamma", "LayerNorm.bias": "Beta" } transposed_weights = { "word_embeddings.weight": np.transpose, "position_embeddings.weight": np.transpose, } # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps)) # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform=transposed_weights) optim = torch.optim.SGD(torch_model.parameters(), 0.01, weight_decay=0.0, momentum=0.0) torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs]) torch_loss = l1_lambda * torch.norm(torch_output, 1) torch_loss.backward() optim.step() torch_outputs = [torch_output.detach().numpy()] check_tensors(torch_outputs, outputs) check_model(torch_model, post_proto, torch_to_onnx, transform=transposed_weights)
def test_embedding_projection_fwd(custom_ops): # ------------------- PopART -------------------- builder = popart.Builder(opsets={ "ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1 }) config = BertConfig(vocab_length=9728, embedding_serialization_vocab_steps=4, micro_batch_size=1, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, no_cls_layer=False, inference=True) popart_model = Bert(config, builder=builder) sequence_info = popart.TensorInfo( "UINT32", [config.micro_batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32) } x = popart_model.gather( indices, config.vocab_length, "Embedding_Dict") x = popart_model.norm(x) x = popart_model.dropout(x) with popart_model.builder.nameScope("CLS"): x = popart_model.lm_prediction_head(x) output = popart_model.projection(x) proto = builder.getModelProto() outputs, post_proto = run_py(proto, data, output) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [data[indices].reshape( config.micro_batch_size, config.sequence_length).astype(np.int32)] # ------------------- PyTorch ------------------------- torch_model = EmbeddingProjectionModel( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps, no_cls_layer=config.no_cls_layer)) torch_model.eval() copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, TRANSPOSE_WEIGHTS) torch_model.tie_weights() torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs)
def test_embedding_projection_bwd(custom_ops): l1_lambda = 0.1 # ------------------- PopART -------------------- builder = popart.Builder(opsets={ "ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1 }) config = BertConfig(vocab_length=9728, embedding_serialization_vocab_steps=4, micro_batch_size=1, hidden_size=288, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, no_cls_layer=False, # Currently updating embedding dict with projection is only # available with momentum. And PopART != Pytorch momentum # due to a bootstrapping step on iter 0. update_embedding_dict=False) popart_model = Bert(config, builder=builder) sequence_info = popart.TensorInfo( "UINT32", [config.micro_batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.micro_batch_size * config.sequence_length)).astype( np.uint32) } x = popart_model.gather( indices, config.vocab_length, "Embedding_Dict") x = popart_model.norm(x) x = popart_model.dropout(x) with popart_model.device_scope(nameScope="CLS"): x = popart_model.lm_prediction_head(x) output = popart_model.projection(x) l1 = builder.aiGraphcore.l1loss( [output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) proto = builder.getModelProto() optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py(proto, data, output, loss=l1, optimizer=optimizer) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [data[indices].reshape( config.micro_batch_size, config.sequence_length).astype(np.int32)] # ------------------- PyTorch ------------------------- torch_model = EmbeddingProjectionModel( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps, no_cls_layer=config.no_cls_layer, update_embedding_dict=config.update_embedding_dict)) # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, TORCH_TO_ONNX, transform=TRANSPOSE_WEIGHTS) optim = torch.optim.SGD(torch_model.parameters(), 0.01, weight_decay=0.0, momentum=0.0) torch_output = torch_model(*[torch.from_numpy(t).long() for t in inputs]) torch_loss = l1_lambda * torch.norm(torch_output, 1) torch_loss.backward() optim.step() check_tensors([torch_output.detach().numpy()], outputs, margin=1e-5) check_model(torch_model, post_proto, TORCH_TO_ONNX, transform=TRANSPOSE_WEIGHTS)
def embedding_bwd(custom_ops, mode, momentum, batch_size, batch_serialization_factor, embedding_serialization_vocab_steps, vocab_length=9728, hidden_size=768): # ------------------- PopART -------------------- config = BertConfig( task="SQUAD", vocab_length=vocab_length, batch_size=batch_size, hidden_size=hidden_size, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, update_embedding_dict=True, embedding_serialization_vocab_steps=embedding_serialization_vocab_steps ) popart_model = get_model(config, mode, 'embedding') # Prevent virtualGraph attributes being added to the ops sequence_info = popart.TensorInfo( "UINT32", [config.batch_size * config.sequence_length]) indices = popart_model.builder.addInputTensor(sequence_info) positions = popart_model.builder.addInputTensor(sequence_info) segments = popart_model.builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint(0, config.max_positional_length, (config.batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint(0, 2, (config.batch_size * config.sequence_length)).astype( np.uint32) } if momentum: optimizer = popart.SGD({ "defaultLearningRate": (0.01, True), "defaultMomentum": (momentum, True), "defaultDampening": (0.0, True), "defaultVelocityScaling": (1.0, True), "lossScaling": (1.0, True), "defaultWeightDecay": (0.0, True) }) else: optimizer = popart.ConstSGD(0.01) l1_lambda = 0.1 if mode == ExecutionMode.PHASED: user_options = { "batchSerializationFactor": batch_serialization_factor, "executionPhases": popart_model.total_execution_phases, } output = popart_model(indices, positions, segments) with popart_model.scope_provider(popart_model.builder, popart_model.norm.scope): l1 = popart_model.builder.aiGraphcore.l1loss( [output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) else: user_options = {"enableStochasticRounding": True} with popart_model.builder.nameScope("Embedding"): output = popart_model.embedding(indices, positions, segments) l1 = popart_model.builder.aiGraphcore.l1loss( [output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) num_reps = 5 proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output, ipus=1, loss=l1, num_reps=num_reps, optimizer=optimizer, user_options=user_options, execution_mode=mode) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps, update_embedding_dict=config.update_embedding_dict)) # Turn off dropout torch_model.eval() expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map( TORCH_TO_ONNX[mode], config, mode) copy_weights_to_torch(torch_model, proto, expanded_name_map, remapped_transform_map) optim = torch.optim.SGD(torch_model.parameters(), 0.01, weight_decay=0.0, dampening=0.0, momentum=momentum) if momentum > 0.: for group in optim.param_groups: for p in group['params']: optim.state[p]['momentum_buffer'] = p.data * 0 optim.state[p]['exp_avg'] = p.data * 0 optim.state[p]['exp_avg_sq'] = p.data * 0 
optim.state[p]['step'] = 0 for _ in range(num_reps): torch_output = torch_model( *[torch.from_numpy(t).long() for t in inputs]) torch_loss = l1_lambda * torch.norm(torch_output, 1) torch_loss.backward() optim.step() optim.zero_grad() torch_outputs = [torch_output.detach().numpy()] check_tensors(torch_outputs, outputs, margin=7e-6) expanded_name_map, remapped_transform_map = expand_torch_to_onnx_map( TORCH_TO_ONNX[mode], config, mode) check_model(torch_model, post_proto, expanded_name_map, remapped_transform_map, margin=7e-06)
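# Hedged sketch (not in the original file) of how embedding_bwd might be invoked;
# the real suite presumably parametrises mode, momentum and the serialisation
# factors. The values below are illustrative only.
def test_embedding_bwd_phased_sketch(custom_ops):
    embedding_bwd(custom_ops, ExecutionMode.PHASED, momentum=0.0,
                  batch_size=1, batch_serialization_factor=1,
                  embedding_serialization_vocab_steps=4)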
def test_attention_fwd(mode, micro_batch_size, batch_serialisation_factor, number_attention_splits, attention_bias, split_qkv): # ------------------- PopART -------------------- config = BertConfig(task="PRETRAINING", vocab_length=9728, micro_batch_size=micro_batch_size, hidden_size=768, attention_heads=4, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, no_attn_dropout=True, inference=True, split_qkv=split_qkv, attention_bias=attention_bias, num_attention_splits=number_attention_splits) popart_model = get_model(config, mode, 'attention') input_info = popart.TensorInfo( config.popart_dtype, [config.micro_batch_size * config.sequence_length, config.hidden_size]) input_tensor = popart_model.builder.addInputTensor(input_info) mask_info = popart.TensorInfo( "UINT32", [config.micro_batch_size, config.sequence_length]) mmask_tensor = popart_model.builder.addInputTensor(mask_info) smask_tensor = popart_model.builder.addInputTensor(mask_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype), mmask_tensor: np.random.randint(0, config.mask_tokens + 1, ( config.micro_batch_size, config.sequence_length, )).astype(np.uint32), smask_tensor: np.random.randint(config.mask_tokens, config.sequence_length + 1, ( config.micro_batch_size, config.sequence_length, )).astype(np.uint32) } user_options = {} if mode == ExecutionMode.PHASED: user_options = { "batchSerializationFactor": batch_serialisation_factor, "executionPhases": popart_model.total_execution_phases } output = popart_model(input_tensor, [mmask_tensor, smask_tensor]) else: user_options = {"enableStochasticRounding": True} output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor]) proto = popart_model.builder.getModelProto() outputs, post_proto = run_py(proto, data, output, user_options=user_options, execution_mode=mode) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[input_tensor].reshape(config.micro_batch_size, config.sequence_length, config.hidden_size).astype(np.float32), get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]]) ] # ------------------- PyTorch ------------------------- torch_model = BertAttention( TorchBertConfig(config.vocab_length, config.hidden_size, config.num_layers, config.attention_heads, attention_bias=config.attention_bias, layer_norm_eps=config.layer_norm_eps)) # Turn off dropout torch_model.eval() mapping = TORCH_TO_ONNX[mode] if split_qkv: mapping = TORCH_TO_ONNX_SPLIT_QKV[mode] copy_weights_to_torch(torch_model, proto, mapping, transform=get_transform(split_qkv, config.hidden_size)) # Model to test against torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs)
def fwd_graph(popart_model, torch_model, mapping=None, transform=None, replication_factor=1, replicated_tensor_sharding=False): # ------------------- PopART -------------------- config = popart_model.config builder = popart_model.builder sequence_info = popart.TensorInfo( "UINT32", [config.micro_batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) segments = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (replication_factor, config.micro_batch_size * config.sequence_length)).astype(np.uint32), positions: np.random.randint(0, config.sequence_length, (replication_factor, config.micro_batch_size * config.sequence_length)).astype(np.uint32), segments: np.random.randint(0, 2, (replication_factor, config.micro_batch_size * config.sequence_length)).astype(np.uint32) } output = popart_model.build_graph(indices, positions, segments) ipus = popart_model.total_ipus proto = builder.getModelProto() outputs, _ = run_py(proto, data, output, replication_factor=replication_factor, replicated_tensor_sharding=replicated_tensor_sharding, ipus=ipus) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = { "input_ids": data[indices].reshape(replication_factor * config.micro_batch_size, config.sequence_length).astype(np.int32), "position_ids": data[positions].reshape(replication_factor * config.micro_batch_size, config.sequence_length).astype(np.int32), "token_type_ids": data[segments].reshape(replication_factor * config.micro_batch_size, config.sequence_length).astype(np.int32) } torch_to_onnx = get_mapping(config, init=mapping) transform_weights = get_transform(config, init=transform) # ------------------- PyTorch ------------------------- # Turn off dropout torch_model.eval() copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights) torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs)
def test_embedding_fwd(custom_ops): # ------------------- PopART -------------------- builder = popart.Builder(opsets={ "ai.onnx": 9, "ai.onnx.ml": 1, "ai.graphcore": 1 }) config = BertConfig(vocab_length=9728, batch_size=1, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, custom_ops=['gather'], inference=True) popart_model = Bert(config, builder=builder) # Prevent virtualGraph attributes being added to the ops. popart_model.embedding_scope = popart_model.device_scope(None, None) popart_model.embedding_split_scope = popart_model.embedding_scope sequence_info = popart.TensorInfo( "UINT32", [config.batch_size * config.sequence_length]) indices = builder.addInputTensor(sequence_info) positions = builder.addInputTensor(sequence_info) segments = builder.addInputTensor(sequence_info) data = { indices: np.random.randint(0, config.vocab_length, (config.batch_size * config.sequence_length)).astype( np.uint32), positions: np.random.randint(0, config.max_positional_length, (config.batch_size * config.sequence_length)).astype( np.uint32), segments: np.random.randint(0, 2, (config.batch_size * config.sequence_length)).astype( np.uint32) } # Use the custom embedding for layout output = popart_model.embedding(indices, positions, segments) proto = builder.getModelProto() outputs, post_proto = run_py( proto, data, output, user_options={"enableStochasticRounding": True}) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[t].reshape(config.batch_size, config.sequence_length).astype(np.int32) for t in [indices, positions, segments] ] torch_to_onnx = { "word_embeddings.weight": "Embedding_Dict", "position_embeddings.weight": "Positional_Dict", "token_type_embeddings.weight": "Segment_Dict", "LayerNorm.weight": "Gamma", "LayerNorm.bias": "Beta" } transposed_weights = { "word_embeddings.weight": np.transpose, "position_embeddings.weight": np.transpose, } # ------------------- PyTorch ------------------------- torch_model = BertEmbeddings( TorchBertConfig(config.vocab_length, config.hidden_size, max_position_embeddings=config.max_positional_length, layer_norm_eps=config.layer_norm_eps)) torch_model.eval() copy_weights_to_torch(torch_model, proto, torch_to_onnx, transposed_weights) torch_outputs = run_fwd_model(inputs, torch_model) check_tensors(torch_outputs, outputs)
def bwd_graph(popart_model,
              torch_model,
              popart_loss_fn,
              torch_loss_fn,
              mapping=None,
              transform=None,
              replication_factor=1,
              replicated_tensor_sharding=False,
              opt_type="SGD"):
    np.random.seed(1984)
    random.seed(1984)
    torch.manual_seed(1984)

    # ------------------- PopART --------------------
    config = popart_model.config
    builder = popart_model.builder

    sequence_info = popart.TensorInfo(
        "UINT32", [config.micro_batch_size * config.sequence_length])
    indices = builder.addInputTensor(sequence_info)
    positions = builder.addInputTensor(sequence_info)
    segments = builder.addInputTensor(sequence_info)
    data = {
        indices:
        np.random.randint(0, config.vocab_length,
                          (replication_factor,
                           config.micro_batch_size * config.sequence_length)).astype(np.uint32),
        positions:
        np.random.randint(0, config.sequence_length,
                          (replication_factor,
                           config.micro_batch_size * config.sequence_length)).astype(np.uint32),
        segments:
        np.random.randint(0, 2,
                          (replication_factor,
                           config.micro_batch_size * config.sequence_length)).astype(np.uint32)
    }
    num_reps = 5

    output = popart_model.build_graph(indices, positions, segments)
    ipus = popart_model.total_ipus

    loss = popart_loss_fn(output)

    proto = builder.getModelProto()

    if opt_type == "SGD":
        optimizer = popart.ConstSGD(1e-3)
    elif opt_type == "LAMB":
        optMap = {
            "defaultLearningRate": (1e-3, True),
            "defaultBeta1": (0.9, True),
            "defaultBeta2": (0.999, True),
            "defaultWeightDecay": (0.0, True),
            "maxWeightNorm": (10.0, True),
            "defaultEps": (1e-8, True),
            "lossScaling": (1.0, True),
        }
        optimizer = popart.Adam(optMap, mode=popart.AdamMode.Lamb)
    elif opt_type == "LAMB_NO_BIAS":
        optMap = {
            "defaultLearningRate": (1, False),
            "defaultBeta1": (0, False),
            "defaultBeta2": (0, False),
            "defaultWeightDecay": (0.0, False),
            "defaultEps": (1e-8, False),
            "lossScaling": (1.0, False),
        }
        optimizer = popart.Adam(optMap, mode=popart.AdamMode.LambNoBias)
    else:
        raise ValueError(f"Unknown opt_type={opt_type}")

    outputs, post_proto = run_py(
        proto,
        data,
        output,
        loss=loss,
        optimizer=optimizer,
        replication_factor=replication_factor,
        replicated_tensor_sharding=replicated_tensor_sharding,
        ipus=ipus,
        num_reps=num_reps)

    # ----------------- PopART -> PyTorch ----------------
    proto = onnx.load_model_from_string(proto)

    inputs = {
        "input_ids":
        data[indices].reshape(replication_factor * config.micro_batch_size,
                              config.sequence_length).astype(np.int32),
        "position_ids":
        data[positions].reshape(replication_factor * config.micro_batch_size,
                                config.sequence_length).astype(np.int32),
        "token_type_ids":
        data[segments].reshape(replication_factor * config.micro_batch_size,
                               config.sequence_length).astype(np.int32)
    }

    torch_to_onnx = get_mapping(config, init=mapping)
    transform_weights = get_transform(config, init=transform)

    # ------------------- PyTorch -------------------------
    # Turn off dropout
    torch_model.eval()

    copy_weights_to_torch(torch_model, proto, torch_to_onnx, transform_weights)

    if opt_type == "SGD":
        optim = torch.optim.SGD(torch_model.parameters(),
                                1e-3,
                                weight_decay=0.0,
                                momentum=0.0)
    elif opt_type == "LAMB":
        optim = torch_lamb.Lamb(torch_model.parameters(),
                                lr=1e-3,
                                weight_decay=0.0,
                                biasCorrection=True)
    else:
        # opt_type == "LAMB_NO_BIAS": the exact torch_lamb arguments here are an
        # assumption, mirroring the PopART LambNoBias settings above with bias
        # correction disabled.
        optim = torch_lamb.Lamb(torch_model.parameters(),
                                lr=1.,
                                weight_decay=0.0,
                                biasCorrection=False)

    for _ in range(num_reps):
        torch_outputs = torch_model(
            **{k: torch.from_numpy(t).long() for k, t in inputs.items()})
        torch_loss = torch_loss_fn(torch_outputs)
        torch_loss.backward()
        optim.step()
        optim.zero_grad()

    check_tensors([output.detach().numpy() for output in torch_outputs],
                  outputs,
                  margin=1.5e-06)

    check_model(torch_model, post_proto, torch_to_onnx, transform_weights, margin=5e-5)
def test_attention_bwd(mode, momentum, micro_batch_size, batch_serialisation_factor, number_attention_splits, attention_bias): l1_lambda = 0.1 num_reps = 5 np.random.seed(1984) torch.manual_seed(1984) split_qkv = False # ------------------- PopART -------------------- config = BertConfig(task="PRETRAINING", vocab_length=9728, micro_batch_size=micro_batch_size, hidden_size=768, sequence_length=128, activation_type='relu', popart_dtype="FLOAT", no_dropout=True, no_attn_dropout=True, split_qkv=split_qkv, attention_bias=attention_bias, num_attention_splits=number_attention_splits) popart_model = get_model(config, mode, 'attention') input_info = popart.TensorInfo( config.popart_dtype, [config.micro_batch_size * config.sequence_length, config.hidden_size]) input_tensor = popart_model.builder.addInputTensor(input_info) mask_info = popart.TensorInfo( "UINT32", [config.micro_batch_size, config.sequence_length]) mmask_tensor = popart_model.builder.addInputTensor(mask_info) smask_tensor = popart_model.builder.addInputTensor(mask_info) data = { input_tensor: np.random.normal(0, 0.02, input_info.shape()).astype(config.dtype), mmask_tensor: np.random.randint(0, config.mask_tokens + 1, ( config.micro_batch_size, config.sequence_length, )).astype(np.uint32), smask_tensor: np.random.randint(config.mask_tokens, config.sequence_length + 1, ( config.micro_batch_size, config.sequence_length, )).astype(np.uint32) } user_options = {} if mode == ExecutionMode.PHASED: user_options = { "batchSerializationFactor": batch_serialisation_factor, "executionPhases": popart_model.total_execution_phases } output = popart_model(input_tensor, [mmask_tensor, smask_tensor]) with popart_model.scope_provider(popart_model.builder, popart_model.norm.scope): l1 = popart_model.builder.aiGraphcore.l1loss( [output], l1_lambda, debugPrefix="l1LossVal", reduction=popart.ReductionType.Sum) else: user_options = {} output = popart_model.attention(input_tensor, [mmask_tensor, smask_tensor]) l1 = popart_model.builder.aiGraphcore.l1loss( [output], l1_lambda, reduction=popart.ReductionType.Sum) proto = popart_model.builder.getModelProto() if momentum: optimizer = popart.SGD({ "defaultLearningRate": (0.01, True), "defaultMomentum": (momentum, True) }) else: optimizer = popart.ConstSGD(0.01) outputs, post_proto = run_py(proto, data, (output, l1), loss=l1, optimizer=optimizer, num_reps=num_reps, user_options=user_options, execution_mode=mode) # ----------------- PopART -> PyTorch ---------------- proto = onnx.load_model_from_string(proto) inputs = [ data[input_tensor].reshape(config.micro_batch_size, config.sequence_length, config.hidden_size), get_torch_mask(config, [data[mmask_tensor], data[smask_tensor]]) ] # ------------------- PyTorch ------------------------- torch_model = BertAttention( TorchBertConfig(config.vocab_length, config.hidden_size, config.num_layers, config.attention_heads, attention_bias=config.attention_bias, layer_norm_eps=config.layer_norm_eps)) # Turn off dropout torch_model.eval() mapping = TORCH_TO_ONNX[mode] if split_qkv: mapping = TORCH_TO_ONNX_SPLIT_QKV[mode] copy_weights_to_torch(torch_model, proto, mapping, transform=get_transform(split_qkv, config.hidden_size)) optim = torch.optim.SGD(torch_model.parameters(), 0.01, weight_decay=0.0, momentum=momentum) if momentum: for group in optim.param_groups: for p in group['params']: optim.state[p]['momentum_buffer'] = p.data * 0 optim.state[p]['exp_avg'] = p.data * 0 optim.state[p]['exp_avg_sq'] = p.data * 0 optim.state[p]['step'] = 0 for _ in range(num_reps): torch_output = 
torch_model( *[torch.from_numpy(t).float() for t in inputs])[0] torch_loss = l1_lambda * torch.norm(torch_output, 1) torch_loss.backward() optim.step() optim.zero_grad() check_tensors([torch_output.detach().numpy()], outputs, margin=6e-07) check_model(torch_model, post_proto, mapping, transform=get_transform(split_qkv, config.hidden_size), margin=2e-7)