def testORTDeterministicCompute(seed, device):
    """Check that two identically seeded ORTTrainer runs agree on their weights.

    Two separate trainers are built with ``deterministic_compute`` enabled,
    each from a model loaded right after re-seeding with ``seed``. Both run a
    single ``train_step`` on the same batch, and their ONNX weights are then
    compared through ``_test_helpers.assert_onnx_weights``.

    Args:
        seed: Value passed to both ``torch.manual_seed`` and ``set_seed``
            before each model load.
        device: Device id forwarded to the trainer options and to the model
            loader.
    """

    def _reseed():
        # Reset the RNG state so each model load starts from the same point.
        torch.manual_seed(seed)
        set_seed(seed)

    # Configuration shared by both trainers.
    optim_config = optim.LambConfig()
    option_values = {
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
            'mem_limit': 10*1024*1024
        }
    }
    opts = orttrainer.ORTTrainerOptions(option_values)

    # First ORTTrainer run.
    _reseed()
    loaded = _load_pytorch_transformer_model(device)
    first_model, model_desc, my_loss, batcher_fn, train_data, _, _ = loaded
    first_trainer = orttrainer.ORTTrainer(first_model,
                                          model_desc,
                                          optim_config,
                                          loss_fn=my_loss,
                                          options=opts)
    data, targets = batcher_fn(train_data, 0)
    first_trainer.train_step(data, targets)
    assert first_trainer._onnx_model is not None

    # Second ORTTrainer run: only the freshly loaded model differs; the
    # description, loss and batch from the first load are reused.
    _reseed()
    second_model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    second_trainer = orttrainer.ORTTrainer(second_model,
                                           model_desc,
                                           optim_config,
                                           loss_fn=my_loss,
                                           options=opts)
    second_trainer.train_step(data, targets)
    assert second_trainer._onnx_model is not None

    # The two trainers must hold distinct model objects, compared by weights.
    assert first_trainer._onnx_model is not second_trainer._onnx_model
    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
def testToyBERTSaveAsONNX():
    """Round-trip a toy BERT trainer through ``save_as_onnx`` and compare.

    The trainer's model is saved to a temporary ONNX file, read back, and used
    to build a second trainer. The test then checks that both trainers hold
    distinct model objects with matching initializer names, identical
    printable graphs, and weights that pass
    ``_test_helpers.assert_onnx_weights``.

    The temporary file is removed even when saving or reloading fails, so a
    failed run does not leave it behind in the working directory.
    """
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'

    # Start from a clean slate in case an earlier run left the file behind.
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

    try:
        trainer.save_as_onnx(onnx_file_name)
        assert os.path.exists(onnx_file_name)

        with open(onnx_file_name, "rb") as f:
            bin_str = f.read()
        reload_onnx_model = onnx.load_model_from_string(bin_str)
    finally:
        # Guarded removal: the file may not exist if save_as_onnx itself failed.
        if os.path.exists(onnx_file_name):
            os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts)
    assert trainer_from_onnx._onnx_model is not None
    assert trainer_from_onnx._onnx_model is not trainer._onnx_model

    original_graph = trainer._onnx_model.graph
    reloaded_graph = trainer_from_onnx._onnx_model.graph
    for initializer, loaded_initializer in zip(original_graph.initializer, reloaded_graph.initializer):
        assert initializer.name == loaded_initializer.name

    assert onnx.helper.printable_graph(reloaded_graph) == onnx.helper.printable_graph(original_graph)
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)