Example #1
def main_worker(args):
    global best_acc1

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.topology))
        model = models.__dict__[args.topology](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.topology))
        model = models.__dict__[args.topology]()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.prune:
        from lpot.experimental import Pruning, common
        prune = Pruning(args.config)

        # Wrap the FP32 model, run pruning as driven by the YAML config,
        # then save the returned lpot model wrapper.
        prune.model = common.Model(model)
        model = prune()
        model.save(args.output_model)
        return
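
For context, here is a minimal sketch of the argument parser this worker appears to expect. The flag names are inferred from the attributes the snippet reads (topology, pretrained, resume, gpu, prune, config, output_model) and are assumptions, not taken from the original script:

import argparse

def build_parser():
    # Hypothetical parser; it covers only the attributes main_worker reads.
    parser = argparse.ArgumentParser(description='lpot pruning worker (sketch)')
    parser.add_argument('--topology', default='resnet18',
                        help='torchvision model name looked up in models.__dict__')
    parser.add_argument('--pretrained', action='store_true')
    parser.add_argument('--resume', default='', help='path to a checkpoint file')
    parser.add_argument('--gpu', default=None, type=int)
    parser.add_argument('--start-epoch', default=0, type=int)
    parser.add_argument('--prune', action='store_true')
    parser.add_argument('--config', default='prune.yaml',
                        help='pruning YAML consumed by lpot Pruning()')
    parser.add_argument('--output-model', default='saved_results',
                        help='path handed to model.save()')
    return parser

if __name__ == '__main__':
    main_worker(build_parser().parse_args())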
Example #2
    def test_pruning(self):
        from lpot.experimental import Pruning, common
        prune = Pruning('fake.yaml')

        dummy_dataset = PyTorchDummyDataset((100, 3, 256, 256), label=True)
        dummy_dataloader = PyTorchDataLoader(dummy_dataset)

        def training_func_for_lpot(model):
            epochs = 16
            iters = 30
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            for nepoch in range(epochs):
                model.train()
                cnt = 0
                prune.on_epoch_begin(nepoch)
                for image, target in dummy_dataloader:
                    prune.on_batch_begin(cnt)
                    print('.', end='')
                    cnt += 1
                    output = model(image)
                    loss = criterion(output, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    prune.on_batch_end()
                    if cnt >= iters:
                        break
                prune.on_epoch_end()

        prune.model = common.Model(self.model)
        prune.q_func = training_func_for_lpot
        prune.eval_dataloader = dummy_dataloader
        _ = prune()
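
Two details make this test work: training_func_for_lpot captures prune and dummy_dataloader as closures, so the lifecycle hooks (on_epoch_begin, on_batch_begin, on_batch_end, on_epoch_end) need no extra plumbing, and the labeled dataloader is bound before prune() invokes the function. If lpot's dummy-data helpers are unavailable, a plain-PyTorch stand-in could look like this sketch (the 1000-class label range is an assumption):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dummy_loader(n=100, shape=(3, 256, 256), num_classes=1000, batch_size=1):
    # Random images plus integer labels, mirroring label=True above.
    images = torch.randn(n, *shape)
    labels = torch.randint(0, num_classes, (n,))
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size)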
Example #3
def train(args, train_dataset, model, tokenizer):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size,
                                  collate_fn=collate_fn)
    # Both closures resolve `prune` lazily at call time, so defining them
    # before the Pruning object is created below is safe.
    def train_func(model):
        return take_train_steps(args, model, tokenizer, train_dataloader, prune)

    def eval_func(model):
        return take_eval_steps(args, model, tokenizer, prune)

    if args.prune:
        from lpot.experimental import Pruning, common
        prune = Pruning(args.config)
        prune.model = common.Model(model)
        prune.train_dataloader = train_dataloader
        prune.pruning_func = train_func
        prune.eval_dataloader = train_dataloader
        prune.eval_func = eval_func
        model = prune()
        torch.save(model, args.output_model)
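
take_train_steps and take_eval_steps are defined elsewhere in the script. As a rough guide, the training helper presumably wraps the usual loop in the same lpot hooks seen in the other examples; the body below is a hypothetical reconstruction, not the original code:

def take_train_steps(args, model, tokenizer, train_dataloader, prune):
    # Hypothetical sketch: assumes the dataloader yields keyword-arg dicts
    # and an HF-style model that returns the loss first when labels are given.
    model.train()
    for epoch in range(int(args.num_train_epochs)):  # assumed argument name
        prune.on_epoch_begin(epoch)
        for step, batch in enumerate(train_dataloader):
            prune.on_batch_begin(step)
            loss = model(**batch)[0]
            loss.backward()
            prune.on_batch_end()  # optimizer/scheduler steps omitted here
        prune.on_epoch_end()
    return model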
Example #4
    def test_pruning_external(self):
        from lpot.experimental import common
        from lpot import Pruning
        prune = Pruning('fake.yaml')
        datasets = DATASETS('pytorch')
        dummy_dataset = datasets['dummy'](shape=(100, 3, 224, 224),
                                          low=0.,
                                          high=1.,
                                          label=True)
        dummy_dataloader = PyTorchDataLoader(dummy_dataset)

        def training_func_for_lpot(model):
            epochs = 16
            iters = 30
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            for nepoch in range(epochs):
                model.train()
                cnt = 0
                prune.on_epoch_begin(nepoch)
                for image, target in dummy_dataloader:
                    prune.on_batch_begin(cnt)
                    print('.', end='')
                    cnt += 1
                    output = model(image)
                    loss = criterion(output, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    prune.on_batch_end()
                    if cnt >= iters:
                        break
                prune.on_epoch_end()

        # Note: the run is configured twice here, once via attributes and
        # again via call arguments on the top-level lpot.Pruning entry point.
        prune.model = common.Model(self.model)
        prune.pruning_func = training_func_for_lpot
        prune.eval_dataloader = dummy_dataloader
        prune.train_dataloader = dummy_dataloader
        _ = prune(common.Model(self.model),
                  train_dataloader=dummy_dataloader,
                  pruning_func=training_func_for_lpot,
                  eval_dataloader=dummy_dataloader)
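
Since everything is supplied through the call arguments, the attribute assignments could be dropped entirely, as in this pared-down sketch of the same test:

from lpot import Pruning
from lpot.experimental import common

prune = Pruning('fake.yaml')
_ = prune(common.Model(self.model),
          train_dataloader=dummy_dataloader,
          pruning_func=training_func_for_lpot,
          eval_dataloader=dummy_dataloader)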
Example #5
def main_worker(gpu, args):
    global best_acc1
    print("Use CPU: {} for training".format(gpu))

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True, quantize=False)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    if args.prune:
        from lpot.experimental import Pruning, common
        prune = Pruning(args.config)

        def training_func_for_lpot(model):
            epochs = 16
            iters = 30
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            for nepoch in range(epochs):
                model.train()
                cnt = 0
                prune.on_epoch_begin(nepoch)
                for image, target in train_loader:
                    prune.on_batch_begin(cnt)
                    print('.', end='')
                    cnt += 1
                    output = model(image)
                    loss = criterion(output, target)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    prune.on_batch_end()
                    if cnt >= iters:
                        break
                prune.on_epoch_end()
                if nepoch > 3:
                    # Freeze quantizer parameters (a QAT carry-over; no-op
                    # unless the model actually contains observers)
                    model.apply(torch.quantization.disable_observer)
                if nepoch > 2:
                    # Freeze batch norm mean and variance estimates (likewise
                    # only affects QAT fused modules)
                    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
            validate(val_loader, model, criterion, args)

            return

        prune.model = common.Model(model)
        prune.q_func = training_func_for_lpot
        q_model = prune()
        return
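
As written, the pruned q_model is discarded when main_worker returns. Mirroring Example #1, the wrapper returned by prune() could be saved first; this assumes the script defines an output path argument like Example #1's args.output_model:

        q_model = prune()
        if q_model is not None:
            # Same save pattern as Example #1.
            q_model.save(args.output_model)
        return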