Example #1
def prefetch_pytorch_loader(args, train=True, pin_memory=True):
    import torch
    from torchvision import transforms

    from MixedPrecision.tools.prefetcher import DataPreFetcher
    from MixedPrecision.tools.stats import StatStream
    import MixedPrecision.tools.utils as utils

    # Standard ImageNet-style augmentation; mean/std normalization is left to the prefetcher below
    data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ])

    # TimedImageFolder is assumed to be imported at module level in the original example
    train_dataset = TimedImageFolder(args.data, data_transforms)

    loader = torch.utils.data.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=args.workers,
                                         pin_memory=pin_memory,
                                         collate_fn=utils.timed_fast_collate)

    # ImageNet per-channel mean/std scaled to the 0-255 pixel range,
    # cast to half precision when fp16 is enabled
    mean = utils.enable_half(
        torch.tensor([0.485 * 255, 0.456 * 255,
                      0.406 * 255]).float()).view(1, 3, 1, 1)
    std = utils.enable_half(
        torch.tensor([0.229 * 255, 0.224 * 255,
                      0.225 * 255]).float()).view(1, 3, 1, 1)

    # The prefetcher applies the mean/std normalization and tracks CPU/GPU timing statistics
    return DataPreFetcher(loader,
                          mean=mean,
                          std=std,
                          cpu_stats=StatStream(drop_first_obs=10),
                          gpu_stats=StatStream(drop_first_obs=10))
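
A usage sketch for the loader above, under the assumption that DataPreFetcher is iterable and yields normalized (x, y) batches, which is how the train() function below consumes its data argument. The Namespace fields and the dataset path are placeholders, not values from the original project.

import argparse

# Hypothetical arguments covering only the fields prefetch_pytorch_loader reads
args = argparse.Namespace(data='/data/imagenet/train', batch_size=256, workers=4)
prefetcher = prefetch_pytorch_loader(args)

for x, y in prefetcher:
    # x: normalized image batch on the device, y: matching labels
    break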
def train(args, model, data):
    import time

    import torch.nn as nn
    import torch.optim as optim

    import MixedPrecision.tools.utils as utils
    from MixedPrecision.tools.optimizer import OptimizerAdapter
    from MixedPrecision.tools.stats import StatStream

    # Move the model and loss to the GPU and optionally cast them to half precision;
    # both helpers are gated by the global flags set through utils.set_use_gpu/set_use_half
    model = utils.enable_cuda(model)
    model = utils.enable_half(model)

    criterion = utils.enable_cuda(nn.CrossEntropyLoss())
    criterion = utils.enable_half(criterion)

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    # Wrap the optimizer so static or dynamic loss scaling can be applied during fp16 training
    optimizer = OptimizerAdapter(
        optimizer,
        static_loss_scale=args.static_loss_scale,
        dynamic_loss_scale=args.dynamic_loss_scale
    )

    model.train()

    # Running statistics for per-epoch compute time (first observation dropped as warm-up)
    compute_time = StatStream(1)
    floss = float('inf')

    for epoch in range(0, args.epochs):
        cstart = time.time()

        for batch in data:
            x, y = batch

            x = utils.enable_cuda(x)
            y = utils.enable_cuda(y)

            # Only the inputs are cast to half; the integer class targets stay as-is
            x = utils.enable_half(x)
            out = model(x)

            loss = criterion(out, y)

            floss = loss.item()

            optimizer.zero_grad()
            # backward() goes through the adapter so the loss can be scaled before backprop
            optimizer.backward(loss)
            optimizer.step()

        cend = time.time()
        compute_time += cend - cstart

        print('[{:4d}] Compute Time (avg: {:.4f}, sd: {:.4f}) Loss: {:.4f}'.format(
            1 + epoch, compute_time.avg, compute_time.sd, floss))
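
The OptimizerAdapter wrapped around SGD above is not shown in these snippets. The sketch below illustrates what static loss scaling typically does, so the optimizer.backward(loss) / optimizer.step() calls read naturally; the class name and internals are illustrative assumptions, not the actual MixedPrecision API.

class StaticLossScaler:
    # Illustrative wrapper: scale the loss before backprop, unscale gradients before the update
    def __init__(self, optimizer, loss_scale=128.0):
        self.optimizer = optimizer
        self.loss_scale = loss_scale

    def zero_grad(self):
        self.optimizer.zero_grad()

    def backward(self, loss):
        # Scaling the loss keeps small fp16 gradients from underflowing to zero
        (loss * self.loss_scale).backward()

    def step(self):
        # Undo the scaling so the parameter update sees the true gradient magnitudes
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)
        self.optimizer.step()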
def load_mnist(args, fake_data=False, hwc_permute=False, shape=(1, 28, 28)):
    import torch
    from torchvision import datasets, transforms

    import MixedPrecision.tools.utils as utils

    # Identity transform by default; swapped for a CHW -> HWC permute when requested
    perm = transforms.Lambda(lambda x: x)

    if hwc_permute:
        perm = transforms.Lambda(lambda x: x.permute(1, 2, 0))

    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        perm,
        transforms.Lambda(lambda x: utils.enable_half(utils.enable_cuda(x)))
    ])

    dataset = None

    if fake_data:
        # fakeit is assumed to be imported at module level in the original example
        dataset = fakeit('pytorch', args.batch_size * 10, shape, 10, trans)
    else:
        dataset = datasets.MNIST(args.data + '/', train=True, download=True, transform=trans)

    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)

    return train_loader
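
A usage sketch for load_mnist, assuming the global GPU/half flags are switched off so batches stay on the CPU in fp32; the batch size and data path are placeholders.

import argparse
import MixedPrecision.tools.utils as utils

utils.set_use_gpu(False)
utils.set_use_half(False)

args = argparse.Namespace(batch_size=64, data='/tmp/mnist')
loader = load_mnist(args, fake_data=False)

for x, y in loader:
    # With hwc_permute=False the images keep the CHW layout: [64, 1, 28, 28]
    print(x.shape, y.shape)
    break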
Example #4
def main():
    import sys

    import torch

    from MixedPrecision.pytorch.mnist_fully_connected import load_mnist
    from MixedPrecision.pytorch.mnist_fully_connected import train
    from MixedPrecision.pytorch.mnist_fully_connected import init_weights
    from MixedPrecision.tools.args import get_parser
    from MixedPrecision.tools.utils import summary
    import MixedPrecision.tools.utils as utils

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    parser = get_parser()
    args = parser.parse_args()

    utils.set_use_gpu(args.gpu)
    utils.set_use_half(args.half)

    shape = (1, 28, 28)
    if args.fake:
        shape = args.shape

    for k, v in vars(args).items():
        print('{:>30}: {}'.format(k, v))

    try:
        current_device = torch.cuda.current_device()
        print('{:>30}: {}'.format('GPU Count', torch.cuda.device_count()))
        print('{:>30}: {}'.format('GPU Name',
                                  torch.cuda.get_device_name(current_device)))
    except Exception:
        # No usable CUDA device; keep running on the CPU
        pass

    # MnistConvolution is assumed to be defined or imported at module level in this example
    model = MnistConvolution(input_shape=shape,
                             conv_num=args.conv_num,
                             kernel_size=args.kernel_size,
                             explicit_permute=args.permute)

    model.float()
    model.apply(init_weights)
    model = utils.enable_cuda(model)
    # Print the layer summary before the optional half-precision cast
    summary(model, input_size=(shape[0], shape[1], shape[2]))
    model = utils.enable_half(model)

    train(
        args, model,
        load_mnist(args,
                   hwc_permute=args.permute,
                   fake_data=args.fake,
                   shape=shape))

    sys.exit(0)
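
The utils.enable_cuda / utils.enable_half helpers and the set_use_gpu / set_use_half switches appear in every example above, but their implementation is not shown. Below is a minimal sketch, assuming they simply gate .cuda() and .half() on module-level flags; it is illustrative only, not the actual MixedPrecision.tools.utils code.

import torch

_use_gpu = False
_use_half = False

def set_use_gpu(value):
    # Only take the GPU path when CUDA is actually available (assumption)
    global _use_gpu
    _use_gpu = bool(value) and torch.cuda.is_available()

def set_use_half(value):
    global _use_half
    _use_half = bool(value)

def enable_cuda(obj):
    # Tensors and nn.Modules both expose .cuda(), so one helper covers both
    return obj.cuda() if _use_gpu else obj

def enable_half(obj):
    return obj.half() if _use_half else obj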