Exemple #1
0
def testORTDeterministicCompute(seed, device):
    """Run one training step on two identically seeded trainers and compare them.

    Both trainers are created with ``deterministic_compute`` enabled and share
    the same model description, loss function, optimizer config, options and
    input batch. The test requires that each trainer ends up with its own
    ``_onnx_model`` object and then hands both trainers to
    ``_test_helpers.assert_onnx_weights`` for comparison.

    Args:
        seed: Value passed to ``torch.manual_seed`` and ``set_seed`` before
            each model is loaded.
        device: Device identifier used both for the trainer options and for
            loading the PyTorch transformer model.
    """

    def _reseed():
        # Reset the RNG state so both model loads start from the same point.
        torch.manual_seed(seed)
        set_seed(seed)

    # Configuration shared by both trainers
    lamb_config = optim.LambConfig()
    trainer_options = orttrainer.ORTTrainerOptions({
        'debug': {'deterministic_compute': True},
        'device': {'id': device, 'mem_limit': 10 * 1024 * 1024},
    })

    # First run: load the model and take a single training step
    _reseed()
    (first_model, model_desc, my_loss, batcher_fn,
     train_data, val_data, _) = _load_pytorch_transformer_model(device)
    first_trainer = orttrainer.ORTTrainer(
        first_model, model_desc, lamb_config, loss_fn=my_loss, options=trainer_options)
    data, targets = batcher_fn(train_data, 0)
    first_trainer.train_step(data, targets)
    first_onnx = first_trainer._onnx_model
    assert first_onnx is not None

    # Second run: freshly loaded model, everything else reused from the first run
    _reseed()
    second_model, _, _, _, _, _, _ = _load_pytorch_transformer_model(device)
    second_trainer = orttrainer.ORTTrainer(
        second_model, model_desc, lamb_config, loss_fn=my_loss, options=trainer_options)
    second_trainer.train_step(data, targets)
    second_onnx = second_trainer._onnx_model
    assert second_onnx is not None

    # The trainers must hold distinct ONNX model objects; their weights are
    # then compared by the shared test helper.
    assert first_onnx is not second_onnx
    _test_helpers.assert_onnx_weights(first_trainer, second_trainer)
def testToyBERTSaveAsONNX():
    """Save a trainer's ONNX model to disk, reload it, and compare both trainers.

    A trainer is built from the toy BERT ONNX model and exported with
    ``save_as_onnx``. The exported file is read back, a second trainer is
    built from the reloaded model, and the two trainers are required to have
    distinct ``_onnx_model`` objects with the same initializer names (in the
    same order and count), the same printable graph, and weights accepted by
    ``_test_helpers.assert_onnx_weights``.

    The temporary ONNX file is removed even when saving or reloading fails.
    """
    device = 'cuda'
    onnx_file_name = '_____temp_toy_bert_onnx_model.onnx'

    # Start from a clean slate so the existence check after saving is meaningful.
    if os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)
    assert not os.path.exists(onnx_file_name)

    # Load trainer
    model_desc = bert_model_description()
    model = load_bert_onnx_model()

    optim_config = optim.LambConfig()
    opts = orttrainer.ORTTrainerOptions({
        'debug': {
            'deterministic_compute': True
        },
        'device': {
            'id': device,
        },
    })

    trainer = orttrainer.ORTTrainer(model,
                                    model_desc,
                                    optim_config,
                                    options=opts)

    # Keep save + reload inside try/finally so a failure while reading or
    # parsing the file does not leave the temporary artifact behind.
    try:
        trainer.save_as_onnx(onnx_file_name)
        assert os.path.exists(onnx_file_name)

        with open(onnx_file_name, "rb") as f:
            reload_onnx_model = onnx.load_model_from_string(f.read())
    finally:
        if os.path.exists(onnx_file_name):
            os.remove(onnx_file_name)

    # Create a new trainer from persisted ONNX model and compare with original ONNX model
    trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model,
                                              model_desc,
                                              optim_config,
                                              options=opts)
    original_onnx = trainer._onnx_model
    reloaded_onnx = trainer_from_onnx._onnx_model
    assert reloaded_onnx is not None
    assert reloaded_onnx is not original_onnx

    # zip() stops at the shorter sequence, so check the counts first; otherwise
    # a missing or extra initializer would slip past the pairwise name check.
    original_initializers = original_onnx.graph.initializer
    reloaded_initializers = reloaded_onnx.graph.initializer
    assert len(original_initializers) == len(reloaded_initializers)
    for initializer, loaded_initializer in zip(original_initializers,
                                               reloaded_initializers):
        assert initializer.name == loaded_initializer.name

    assert (onnx.helper.printable_graph(reloaded_onnx.graph) ==
            onnx.helper.printable_graph(original_onnx.graph))
    _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx)