Example #1
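    # Autoencoder __init__ (snippet): builds encoder/decoder MLPs for 28x28 inputs
    # and optionally wraps each in ORTModule so they run through ONNX Runtime.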
    def __init__(self, lr, use_ortmodule=True):
        super().__init__()
        self.lr = lr
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(),
                                     nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(),
                                     nn.Linear(64, 28 * 28))
        if use_ortmodule:
            self.encoder = ORTModule(self.encoder)
            self.decoder = ORTModule(self.decoder)
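    # Checks that a float16 input keeps its dtype through ORTModule on the "ort" device.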
    def test_half_type(self):
        model = NoOpNet()
        device = torch.device("ort")
        model.to(device)
        model = ORTModule(model)
        input = torch.ones(2, 2).to(torch.float16)
        y = model(input.to(device))
        assert y.dtype == torch.float16
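    # Inference-only path: wrap the model in ORTModule and run a forward pass
    # under torch.no_grad() on the "ort" device.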
    def test_ortmodule_inference(self):
        input_size = 784
        hidden_size = 500
        num_classes = 10
        batch_size = 128
        model = NeuralNet(input_size, hidden_size, num_classes)
        device = torch.device("ort")
        model.to(device)
        model = ORTModule(model)

        with torch.no_grad():
            data = torch.rand(batch_size, input_size)
            y = model(data.to(device))
        print("Done")
    def test_ort_module_and_eager_mode(self):
        input_size = 784
        hidden_size = 500
        num_classes = 10
        batch_size = 128
        model = NeuralNet(input_size, hidden_size, num_classes)
        optimizer = optim.SGD(model.parameters(), lr=0.01)

        data = torch.rand(batch_size, input_size)
        target = torch.randint(0, 10, (batch_size, ))
        # save the initial state
        initial_state = model.state_dict()
        # run on cpu first
        x = model(data)
        loss = my_loss(x, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # record the updated parameters
        cpu_updated_state = model.state_dict()
        # reload initial state
        model.load_state_dict(initial_state)
        # run on ort with ORTModule and eager mode
        # use device_idx 1 to test non-zero device
        torch_ort_eager.set_device(1, "CPUExecutionProvider",
                                   {"dummy": "dummy"})
        device = torch.device("ort", index=0)
        model.to(device)
        model = ORTModule(model)
        ort_optimizer = optim.SGD(model.parameters(), lr=0.01)
        x = model(data.to(device))
        loss = my_loss(x.cpu(), target)
        loss.backward()
        ort_optimizer.step()
        ort_optimizer.zero_grad()

        ort_updated_state = model.state_dict()
        # compare the updated state
        for state_tensor in cpu_updated_state:
            assert state_tensor in ort_updated_state
            assert torch.allclose(cpu_updated_state[state_tensor],
                                  ort_updated_state[state_tensor].cpu(),
                                  atol=1e-3)
Example #5
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=300,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 300)'
    )
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--epochs',
                        type=int,
                        default=5,
                        metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument('--data-dir',
                        type=str,
                        default='./mnist',
                        help='Path to the mnist data directory')

    args = parser.parse_args()

    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    ## Data loader
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        args.data_dir,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = None
    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_dir,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500,
                      num_classes=10).to(device)
    if not args.pytorch_only:
        print('Training MNIST on ORTModule....')
        model = ORTModule(model)

        # TODO: change it to False to stop saving ONNX models
        model._save_onnx = True
        model._save_onnx_prefix = 'MNIST'
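        # internal debug attributes: with _save_onnx set, ORTModule writes the
        # exported ONNX model(s) to disk using the given prefix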

        # Set log level
        numeric_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.log_level)
        logging.basicConfig(level=numeric_level)
    else:
        print('Training MNIST on vanilla PyTorch....')
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Train loop
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(0, args.epochs):
        total_training_time += train(args, model, device, optimizer, my_loss,
                                     train_loader, epoch)
        if not args.pytorch_only and epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, device, my_loss,
                                                  test_loader)
            total_test_time += test_time

    assert validation_accuracy > 0.92

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))
Example #6
def train(rank: int, args, world_size: int, epochs: int):

    # DDP init example
    dist_init(rank, world_size)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    # Problem statement
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(rank)

    if args.use_ortmodule:
        print("Converting to ORTModule....")
        model = ORTModule(model)

    train_dataloader, test_dataloader = get_dataloader(args, rank,
                                                       args.batch_size)
    loss_fn = my_loss
    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {
    }  # pass any optimizer specific arguments here, or directly below when instantiating OSS
    if args.use_sharded_optimizer:
        # Wrap the base optimizer in OSS so that its state is sharded across ranks
        optimizer = OSS(params=model.parameters(),
                        optim=base_optimizer,
                        lr=args.lr,
                        **base_optimizer_arguments)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(epochs):
        total_training_time += train_step(args, model, rank, optimizer,
                                          loss_fn, train_dataloader, epoch)
        if epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, rank, loss_fn,
                                                  test_dataloader)
            total_test_time += test_time

    print('\n======== Global stats ========')
    if args.use_ortmodule:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))

    dist.destroy_process_group()
Example #7
def benchmark(N=1000,
              n_features=20,
              hidden_layer_sizes="26,25",
              max_iter=1000,
              learning_rate_init=1e-4,
              batch_size=100,
              run_torch=True,
              device='cpu',
              opset=12,
              profile='fct'):
    """
    Compares :epkg:`onnxruntime-training` (via ORTModule) to PyTorch for
    training. Training algorithm is SGD.

    :param N: number of observations to train on
    :param n_features: number of features
    :param hidden_layer_sizes: hidden layer sizes, comma separated values
    :param max_iter: number of iterations
    :param learning_rate_init: initial learning rate
    :param batch_size: batch size
    :param run_torch: also train the plain PyTorch model under the same
        conditions (True) or skip the PyTorch baseline (False)
    :param device: `'cpu'` or `'cuda'`
    :param opset: opset to choose for the conversion
    :param profile: 'fct' to use cProfile, 'event' to use WithEventProfiler
    """
    N = int(N)
    n_features = int(n_features)
    max_iter = int(max_iter)
    learning_rate_init = float(learning_rate_init)
    batch_size = int(batch_size)
    run_torch = run_torch in (1, True, '1', 'True')

    print("N=%d" % N)
    print("n_features=%d" % n_features)
    print(f"hidden_layer_sizes={hidden_layer_sizes!r}")
    print("max_iter=%d" % max_iter)
    print(f"learning_rate_init={learning_rate_init:f}")
    print("batch_size=%d" % batch_size)
    print(f"run_torch={run_torch!r}")
    print(f"opset={opset!r} (unused)")
    print(f"device={device!r}")
    print(f"profile={profile!r}")
    device0 = device
    device = torch.device("cuda:0" if device in ('cuda', 'cuda:0',
                                                 'gpu') else "cpu")
    print(f"fixed device={device!r}")
    print('------------------')

    if not isinstance(hidden_layer_sizes, tuple):
        hidden_layer_sizes = tuple(map(int, hidden_layer_sizes.split(",")))
    X, y = make_regression(N, n_features=n_features, bias=2)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    class Net(torch.nn.Module):
        def __init__(self, n_features, hidden, n_output):
            super(Net, self).__init__()
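            # layers are kept in a plain Python list; the setattr calls below
            # register each Linear as a submodule so parameters() and .to() see them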
            self.hidden = []

            size = n_features
            for i, hid in enumerate(hidden_layer_sizes):
                self.hidden.append(torch.nn.Linear(size, hid))
                size = hid
                setattr(self, "hid%d" % i, self.hidden[-1])
            self.hidden.append(torch.nn.Linear(size, n_output))
            setattr(self, "predict", self.hidden[-1])

        def forward(self, x):
            for hid in self.hidden:
                x = hid(x)
                x = F.relu(x)
            return x

    nn = Net(n_features, hidden_layer_sizes, 1)
    if device0 == 'cpu':
        nn.cpu()
    else:
        nn.cuda(device=device)
    print(
        f"n_parameters={len(list(nn.parameters()))}, n_layers={len(nn.hidden)}"
    )
    for i, p in enumerate(nn.parameters()):
        print("  p[%d].shape=%r" % (i, p.shape))

    optimizer = torch.optim.SGD(nn.parameters(), lr=learning_rate_init)
    criterion = torch.nn.MSELoss(reduction='sum')  # size_average=False is deprecated
    batch_no = len(X_train) // batch_size

    # warm-up forward pass, run before timing starts (mirrors the
    # "exclude onnx conversion" warm-up used for the ORT model below)
    inputs = torch.tensor(X_train[:1], requires_grad=True, device=device)
    nn(inputs)

    def train_torch():
        for epoch in range(max_iter):
            running_loss = 0.0
            x, y = shuffle(X_train, y_train)
            for i in range(batch_no):
                start = i * batch_size
                end = start + batch_size
                inputs = torch.tensor(x[start:end],
                                      requires_grad=True,
                                      device=device)
                labels = torch.tensor(y[start:end],
                                      requires_grad=True,
                                      device=device)

                def step_torch():
                    optimizer.zero_grad()
                    outputs = nn(inputs)
                    loss = criterion(outputs, torch.unsqueeze(labels, dim=1))
                    loss.backward()
                    optimizer.step()
                    return loss

                loss = step_torch()
                running_loss += loss.item()
        return running_loss

    begin = time.perf_counter()
    if run_torch:
        if profile in ('cProfile', 'fct'):
            from pyquickhelper.pycode.profiling import profile
            running_loss, prof, _ = profile(train_torch, return_results=True)
            dur_torch = time.perf_counter() - begin
            name = f"{device0}.{os.path.split(__file__)[-1]}.tch.prof"
            prof.dump_stats(name)
        elif profile == 'event':

            def clean_name(x):
                return "/".join(x.replace("\\", "/").split('/')[-3:])

            prof = WithEventProfiler(size=10000000, clean_file_name=clean_name)
            with prof:
                running_loss = train_torch()
            dur_torch = time.perf_counter() - begin
            df = prof.report
            name = f"{device0}.{os.path.split(__file__)[-1]}.tch.csv"
            df.to_csv(name, index=False)
        else:
            running_loss = train_torch()
            dur_torch = time.perf_counter() - begin
    else:
        dur_torch = time.perf_counter() - begin

    if run_torch:
        print(f"time_torch={dur_torch!r}, running_loss={running_loss!r}")
        running_loss0 = running_loss
    else:
        running_loss0 = -1

    # ORTModule
    nn = Net(n_features, hidden_layer_sizes, 1)
    if device0 == 'cpu':
        nn.cpu()
    else:
        nn.cuda(device=device)

    nn_ort = ORTModule(nn)
    optimizer = torch.optim.SGD(nn_ort.parameters(), lr=learning_rate_init)
    criterion = torch.nn.MSELoss(reduction='sum')  # size_average=False is deprecated

    # exclude onnx conversion
    inputs = torch.tensor(X_train[:1], requires_grad=True, device=device)
    nn_ort(inputs)

    def train_ort():
        for epoch in range(max_iter):
            running_loss = 0.0
            x, y = shuffle(X_train, y_train)
            for i in range(batch_no):
                start = i * batch_size
                end = start + batch_size
                inputs = torch.tensor(x[start:end],
                                      requires_grad=True,
                                      device=device)
                labels = torch.tensor(y[start:end],
                                      requires_grad=True,
                                      device=device)

                def step_ort():
                    optimizer.zero_grad()
                    outputs = nn_ort(inputs)
                    loss = criterion(outputs, torch.unsqueeze(labels, dim=1))
                    loss.backward()
                    optimizer.step()
                    return loss

                loss = step_ort()
                running_loss += loss.item()
        return running_loss

    begin = time.perf_counter()
    if profile in ('cProfile', 'fct'):
        from pyquickhelper.pycode.profiling import profile
        running_loss, prof, _ = profile(train_ort, return_results=True)
        dur_ort = time.perf_counter() - begin
        name = f"{device0}.{os.path.split(__file__)[-1]}.ort.prof"
        prof.dump_stats(name)
    elif profile == 'event':

        def clean_name(x):
            return "/".join(x.replace("\\", "/").split('/')[-3:])

        prof = WithEventProfiler(size=10000000, clean_file_name=clean_name)
        with prof:
            running_loss = train_ort()
        dur_ort = time.perf_counter() - begin
        df = prof.report
        name = f"{device0}.{os.path.split(__file__)[-1]}.ort.csv"
        df.to_csv(name, index=False)
    else:
        running_loss = train_ort()
        dur_ort = time.perf_counter() - begin

    print(f"time_torch={dur_torch!r}, running_loss={running_loss0!r}")
    print(f"time_ort={dur_ort!r}, last_trained_error={running_loss!r}")
Example #8
def main():
    # 1. Basic setup
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--epochs',
                        type=int,
                        default=4,
                        metavar='N',
                        help='number of epochs to train (default: 4)')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=40,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 40)'
    )
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument(
        '--num-hidden-layers',
        type=int,
        default=1,
        metavar='H',
        help=
        'Number of hidden layers for the BERT model. A vanilla BERT has 12 hidden layers (default: 1)'
    )
    parser.add_argument('--data-dir',
                        type=str,
                        default='./cola_public/raw',
                        help='Path to the bert data directory')

    args = parser.parse_args()

    # Device (CPU vs CUDA)
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")

    # Set log level
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % args.log_level)
    logging.basicConfig(level=numeric_level)

    # 2. Dataloader
    train_dataloader, validation_dataloader = load_dataset(args)

    # 3. Modeling
    # Load BertForSequenceClassification, the pretrained BERT model with a single
    # linear classification layer on top.
    config = AutoConfig.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        num_hidden_layers=args.num_hidden_layers,
        output_attentions=False,  # Whether the model returns attentions weights.
        output_hidden_states=
        False,  # Whether the model returns all hidden-states.
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
        config=config,
    )

    if not args.pytorch_only:
        model = ORTModule(model)

    # TODO: change it to False to stop saving ONNX models
    model._save_onnx = True
    model._save_onnx_prefix = 'BertForSequenceClassification'

    # Tell pytorch to run this model on the GPU.
    if torch.cuda.is_available() and not args.no_cuda:
        model.cuda()

    # Note: AdamW is a class from the huggingface library (as opposed to pytorch)
    optimizer = AdamW(
        model.parameters(),
        lr=2e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8.
    )

    # Authors recommend between 2 and 4 epochs
    # Total number of training steps is number of batches * number of epochs.
    total_steps = len(train_dataloader) * args.epochs

    # Create the learning rate scheduler.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)
    # Seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)
    if torch.cuda.is_available() and not args.no_cuda:
        torch.cuda.manual_seed_all(args.seed)

    # 4. Train loop (fine-tune)
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch_i in range(0, args.epochs):
        total_training_time += train(model, optimizer, scheduler,
                                     train_dataloader, epoch_i, device, args)
        if not args.pytorch_only and epoch_i == 0:
            epoch_0_training = total_training_time
        test_time, validation_accuracy = test(model, validation_dataloader,
                                              device, args)
        total_test_time += test_time

    assert validation_accuracy > 0.5

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))
Example #9
# Model.

if args.run_without_ort:
    model = nn.Sequential(
        nn.Linear(d_in, d_hidden),  # Stage 1
        nn.ReLU(),  # Stage 1
        nn.Linear(d_hidden, d_hidden),  # Stage 1
        nn.ReLU(),  # Stage 1
        nn.Linear(d_hidden, d_hidden),  # Stage 2
        nn.ReLU(),  # Stage 2
        nn.Linear(d_hidden, d_out)  # Stage 2
    )

else:
    model = nn.Sequential(
        ORTModule(nn.Linear(d_in, d_hidden).to(device)),  # Stage 1
        nn.ReLU().to(
            device
        ),  # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported.
        ORTModule(nn.Linear(d_hidden, d_hidden).to(device)),  # Stage 1
        nn.ReLU().to(
            device
        ),  # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported.
        ORTModule(nn.Linear(d_hidden, d_hidden).to(device)),  # Stage 2
        nn.ReLU().to(
            device
        ),  # ORTModule(nn.ReLU().to(device)), Stage 2, TODO: ORTModule can wrap Relu once stateless model is supported.
        ORTModule(nn.Linear(d_hidden, d_out).to(device))  # Stage 2
    )

# Partition the stages marked above into a pipeline-parallel model
model = PipelineModule(
Example #10
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {'num_workers': 0, 'pin_memory': True}
    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        './data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    # set device
    torch_ort_eager.set_device(0, 'CPUExecutionProvider', {})

    device = torch.device('ort', index=0)
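    # modules and tensors moved to this "ort" device are executed through
    # ONNX Runtime's eager mode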
    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)
    model.to(device)
    model = ORTModule(model)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    print('\nStart Training.')

    for epoch in range(1, args.epochs + 1):
        train_with_eager(args, model, optimizer, device, train_loader, epoch)
def main():
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument("--batch-size",
                        type=int,
                        default=64,
                        metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--test-batch-size",
                        type=int,
                        default=1000,
                        metavar="N",
                        help="input batch size for testing (default: 1000)")
    parser.add_argument("--epochs",
                        type=int,
                        default=10,
                        metavar="N",
                        help="number of epochs to train (default: 10)")
    parser.add_argument("--lr",
                        type=float,
                        default=0.01,
                        metavar="LR",
                        help="learning rate (default: 0.01)")
    parser.add_argument("--no-cuda",
                        action="store_true",
                        default=False,
                        help="disables CUDA training")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    kwargs = {"num_workers": 0, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]),
        ),
        batch_size=args.test_batch_size,
        shuffle=True,
        **kwargs,
    )

    # set device
    torch_ort_eager.set_device(0, "CPUExecutionProvider", {})

    device = torch.device("ort", index=0)
    input_size = 784
    hidden_size = 500
    num_classes = 10
    model = NeuralNet(input_size, hidden_size, num_classes)
    model.to(device)
    model = ORTModule(model)
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    print("\nStart Training.")

    for epoch in range(1, args.epochs + 1):
        train_with_eager(args, model, optimizer, device, train_loader, epoch)