Example #1
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=300,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 300)'
    )
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--export-onnx-graphs',
                        action='store_true',
                        default=False,
                        help='export ONNX graphs to current directory')
    parser.add_argument('--epochs',
                        type=int,
                        default=5,
                        metavar='N',
                        help='number of epochs to train (default: 5)')
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument('--data-dir',
                        type=str,
                        default='./mnist',
                        help='Path to the mnist data directory')

    args = parser.parse_args()

    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    ## Data loader
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(args.data_dir,
                       train=True,
                       download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307, ), (0.3081, ))
                       ])),
        batch_size=args.batch_size,
        shuffle=True)
    test_loader = None
    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_dir,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500,
                      num_classes=10).to(device)
    if not args.pytorch_only:
        print('Training MNIST on ORTModule....')

        # Debug options: optionally save the exported ONNX graphs for inspection
        debug_options = DebugOptions(save_onnx=args.export_onnx_graphs,
                                     onnx_prefix='MNIST')

        model = ORTModule(model, debug_options)

        # Set log level
        numeric_level = getattr(logging, args.log_level.upper(), None)
        if not isinstance(numeric_level, int):
            raise ValueError('Invalid log level: %s' % args.log_level)
        logging.basicConfig(level=numeric_level)
    else:
        print('Training MNIST on vanilla PyTorch....')
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Train loop
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(0, args.epochs):
        total_training_time += train(args, model, device, optimizer, my_loss,
                                     train_loader, epoch)
        if not args.pytorch_only and epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, device, my_loss,
                                                  test_loader)
            total_test_time += test_time

    # The accuracy check only makes sense when validation actually ran
    if args.test_batch_size > 0:
        assert validation_accuracy > 0.92

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))
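
# The examples in this file call NeuralNet and my_loss without defining them;
# they are defined elsewhere in the repository. A minimal sketch of what they
# plausibly look like (an assumption, not the repository's exact code):
import torch.nn as nn
import torch.nn.functional as F

class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x.view(x.size(0), -1))  # flatten 28x28 images to 784
        out = self.relu(out)
        return self.fc2(out)

def my_loss(output, target):
    # log-softmax followed by negative log-likelihood
    return F.nll_loss(F.log_softmax(output, dim=1), target)
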
Example #2
def train(rank: int, args, world_size: int, epochs: int):

    # DDP init example
    dist_init(rank, world_size)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(rank)

    if args.use_ortmodule:
        print("Converting to ORTModule....")
        debug_options = DebugOptions(save_onnx=args.export_onnx_graphs,
                                     onnx_prefix="NeuralNet")

        model = ORTModule(model, debug_options)

    train_dataloader, test_dataloader = get_dataloader(args, rank,
                                                       args.batch_size)
    loss_fn = my_loss
    base_optimizer = torch.optim.SGD  # pick any PyTorch-compliant optimizer here
    # Optimizer-specific arguments, if any; these are forwarded to the base
    # optimizer when instantiating OSS below.
    base_optimizer_arguments = {}
    if args.use_sharded_optimizer:
        # Wrap the optimizer in its state-sharding counterpart (fairscale OSS)
        optimizer = OSS(params=model.parameters(),
                        optim=base_optimizer,
                        lr=args.lr,
                        **base_optimizer_arguments)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(epochs):
        total_training_time += train_step(args, model, rank, optimizer,
                                          loss_fn, train_dataloader, epoch)
        if epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, rank, loss_fn,
                                                  test_dataloader)
            total_test_time += test_time

    print("\n======== Global stats ========")
    if args.use_ortmodule:
        estimated_export = 0
        if epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))

    dist.destroy_process_group()
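
# dist_init above is assumed to be a small bootstrap helper around
# torch.distributed; a minimal sketch (address, port, and backend are
# illustrative assumptions and depend on the environment):
import os
import torch.distributed as dist

def dist_init(rank: int, world_size: int, backend: str = "gloo") -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
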
Example #3
def main():
    # 1. Basic setup
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--pytorch-only',
                        action='store_true',
                        default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--view-graphs',
                        action='store_true',
                        default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--export-onnx-graphs',
                        action='store_true',
                        default=False,
                        help='export ONNX graphs to current directory')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--epochs',
                        type=int,
                        default=4,
                        metavar='N',
                        help='number of epochs to train (default: 4)')
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=40,
        metavar='N',
        help=
        'how many batches to wait before logging training status (default: 40)'
    )
    parser.add_argument(
        '--train-steps',
        type=int,
        default=-1,
        metavar='N',
        help=
        'number of steps to train. Set -1 to run through whole dataset (default: -1)'
    )
    parser.add_argument(
        '--log-level',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='WARNING',
        help='Log level (default: WARNING)')
    parser.add_argument(
        '--num-hidden-layers',
        type=int,
        default=1,
        metavar='H',
        help=
        'Number of hidden layers for the BERT model. A vanilla BERT has 12 hidden layers (default: 1)'
    )
    parser.add_argument('--data-dir',
                        type=str,
                        default='./cola_public/raw',
                        help='Path to the bert data directory')

    args = parser.parse_args()

    # Device (CPU vs CUDA)
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda")
        print('There are %d GPU(s) available.' % torch.cuda.device_count())
        print('We will use the GPU:', torch.cuda.get_device_name(0))
    else:
        print('No GPU available, using the CPU instead.')
        device = torch.device("cpu")

    # Set log level
    numeric_level = getattr(logging, args.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % args.log_level)
    logging.basicConfig(level=numeric_level)

    # 2. Dataloader
    train_dataloader, validation_dataloader = load_dataset(args)

    # 3. Modeling
    # Load BertForSequenceClassification, the pretrained BERT model with a single
    # linear classification layer on top.
    config = AutoConfig.from_pretrained(
        "bert-base-uncased",
        num_labels=2,
        num_hidden_layers=args.num_hidden_layers,
        output_attentions=False,  # Whether the model returns attention weights.
        output_hidden_states=False,  # Whether the model returns all hidden states.
    )
    model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased",  # Use the 12-layer BERT model, with an uncased vocab.
        config=config,
    )

    if not args.pytorch_only:
        # Debug options: optionally save the exported ONNX graphs for inspection
        debug_options = DebugOptions(
            save_onnx=args.export_onnx_graphs,
            onnx_prefix='BertForSequenceClassification')

        model = ORTModule(model, debug_options)

    # Tell PyTorch to run this model on the GPU.
    if torch.cuda.is_available() and not args.no_cuda:
        model.cuda()

    # Note: AdamW here is the class from the Hugging Face transformers library
    # (as opposed to torch.optim.AdamW).
    optimizer = AdamW(
        model.parameters(),
        lr=2e-5,  # args.learning_rate - default is 5e-5; our notebook used 2e-5
        eps=1e-8  # args.adam_epsilon  - default is 1e-8
    )

    # The BERT authors recommend fine-tuning for between 2 and 4 epochs.
    # Total number of training steps is number of batches * number of epochs.
    total_steps = len(train_dataloader) * args.epochs

    # Create the learning rate scheduler.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Default value in run_glue.py
        num_training_steps=total_steps)
    # Seed
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)
    if torch.cuda.is_available() and not args.no_cuda:
        torch.cuda.manual_seed_all(args.seed)

    # 4. Train loop (fine-tune)
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch_i in range(0, args.epochs):
        total_training_time += train(model, optimizer, scheduler,
                                     train_dataloader, epoch_i, device, args)
        if not args.pytorch_only and epoch_i == 0:
            epoch_0_training = total_training_time
        test_time, validation_accuracy = test(model, validation_dataloader,
                                              device, args)
        total_test_time += test_time

    assert validation_accuracy > 0.5

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))
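
# Imports assumed by the BERT example above (names as of transformers 4.x;
# AdamW was later deprecated there in favor of torch.optim.AdamW):
import argparse
import logging
import random
import numpy as np
import torch
import onnxruntime
from transformers import (AdamW, AutoConfig, BertForSequenceClassification,
                          get_linear_schedule_with_warmup)
from onnxruntime.training.ortmodule import DebugOptions, ORTModule
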
Example #4
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--train-steps', type=int, default=-1, metavar='N',
                        help='number of steps to train. Set -1 to run through whole dataset (default: -1)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=64, metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=42, metavar='S',
                        help='random seed (default: 42)')
    parser.add_argument('--pytorch-only', action='store_true', default=False,
                        help='disables ONNX Runtime training')
    parser.add_argument('--log-interval', type=int, default=300, metavar='N',
                        help='how many batches to wait before logging training status (default: 300)')
    parser.add_argument('--view-graphs', action='store_true', default=False,
                        help='views forward and backward graphs')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING',
                        help='Log level (default: WARNING)')
    parser.add_argument('--data-dir', type=str, default='./mnist',
                        help='Path to the mnist data directory')

    # DeepSpeed-related settings
    parser.add_argument('--local_rank',
                        type=int,
                        required=True,
                        help='local rank passed from distributed launcher')
    parser = deepspeed.add_config_arguments(parser)
    
    args = parser.parse_args()


    # Common setup
    torch.manual_seed(args.seed)
    onnxruntime.set_seed(args.seed)

    # TODO: CUDA support is broken due to copying from PyTorch into ORT
    if not args.no_cuda and torch.cuda.is_available():
        device = "cuda:" + str(args.local_rank)
    else:
        device = "cpu"

    ## Data loader

    dist.init_process_group(backend='nccl')
    if args.local_rank == 0:
        # download only once on rank 0
        datasets.MNIST(args.data_dir, download=True)
    dist.barrier()
    train_set = datasets.MNIST(args.data_dir,
                               train=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))
                               ]))

    test_loader = None
    if args.test_batch_size > 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(args.data_dir, train=False, transform=transforms.Compose([
                transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])),
            batch_size=args.test_batch_size, shuffle=True)

    # Model architecture
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
    if not args.pytorch_only:
        print('Training MNIST on ORTModule....')

        # Set log level
        log_level_mapping = {
            "DEBUG": LogLevel.VERBOSE,
            "INFO": LogLevel.INFO,
            "WARNING": LogLevel.WARNING,
            "ERROR": LogLevel.ERROR,
            "CRITICAL": LogLevel.FATAL,
        }
        log_level = log_level_mapping.get(args.log_level.upper(), None)
        if not isinstance(log_level, LogLevel):
            raise ValueError('Invalid log level: %s' % args.log_level)
        debug_options = DebugOptions(log_level=log_level, save_onnx=False, onnx_prefix='MNIST')

        model = ORTModule(model, debug_options)

    else:
        print('Training MNIST on vanilla PyTorch....')

    model, optimizer, train_loader, _ = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        training_data=train_set)
    
    # Train loop
    total_training_time, total_test_time, epoch_0_training = 0, 0, 0
    for epoch in range(0, args.epochs):
        total_training_time += train(args, model, device, optimizer, my_loss, train_loader, epoch)
        if not args.pytorch_only and epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            total_test_time += test(args, model, device, my_loss, test_loader)

    print('\n======== Global stats ========')
    if not args.pytorch_only:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1)
            print("  Estimated ONNX export took:               {:.4f}s".format(estimated_export))
        else:
            print("  Estimated ONNX export took:               Estimate available when epochs > 1 only")
        print("  Accumulated training without export took: {:.4f}s".format(total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(total_test_time))