Example #1
def yolo_config(config, args):
    # if (args.prune_threshold > 0.005):
    #     print("WARNING: Prune threshold seems too large.")
    #     if input("Input y if you are sure you want to continue.") != 'y': return

    device = 'cpu' if args.no_cuda else 'cuda:0'
    model = config['model'](config['config_path'], device=device)
    wrapper = YoloWrapper(device, model)
    lr0 = 0.001
    # lr0 = args.lr
    optimizer = config['optimizer'](filter(lambda x: x.requires_grad,
                                           model.parameters()),
                                    lr=lr0,
                                    momentum=args.momentum)
    writer = SummaryWriter()

    print("Loading dataloaders..")
    train_dataloader = LoadImagesAndLabels(config['datasets']['train'],
                                           batch_size=args.batch_size,
                                           img_size=config['image_size'])
    val_dataloader = LoadImagesAndLabels(config['datasets']['test'],
                                         batch_size=args.batch_size,
                                         img_size=config['image_size'])

    if args.pretrained_weights:
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(train_dataloader, val_dataloader, args.epochs, optimizer,
                      lr0)
        torch.save(model.state_dict(), "YOLOv3-gate-prepruned.pt")

    with torch.no_grad():
        pre_prune_mAP, _, _ = wrapper.test(val_dataloader,
                                           img_size=config['image_size'],
                                           batch_size=args.batch_size)

    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)

        print(
            f"Just pruned with prune_perc={prune_perc}, now has {prune_rate(model, verbose=False)}% zeros"
        )

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(train_dataloader,
                                                   val_dataloader, 3,
                                                   optimizer, lr0)

            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)

            print(
                f"Just finished training with prune_perc={prune_perc}, now has {prune_rate(model, verbose=False)}% zeros"
            )
        else:
            with torch.no_grad():
                curr_mAP, _, _ = wrapper.test(val_dataloader,
                                              img_size=config['image_size'],
                                              batch_size=args.batch_size)

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_mAP, pre_prune_mAP)

        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {diff}")

    prune_perc = prune_rate(model)

    if args.save_model:
        #torch.save(model.state_dict(), f'{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt')
        #torch.save(model.state_dict(), "YOLOv3-prune-perc-" + str(prune_perc) + ".pt")
        torch.save(model.state_dict(),
                   "YOLOv3-gate-pruned-modelcompression.pt")

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper
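
All of the config functions on this page stop pruning via the same reached_threshold helper, which is never shown here. A minimal sketch consistent with the call sites above (a stop flag plus the signed change in the metric, with args.prune_threshold as the allowed drop) might look like this; the names come from the call sites, but the sign convention is an assumption:

def reached_threshold(threshold, curr_metric, baseline_metric):
    # Assumed contract: stop once the metric has fallen more than
    # `threshold` below its pre-pruning baseline.
    diff = curr_metric - baseline_metric  # negative once pruning hurts
    return diff < -threshold, diff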
Example #2
def frcnn_config(config, args):
    classes = (
        '__background__',  # always index 0
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor')

    model = config['model'](
        classes
        # model_path = args.pretrained_weights
    )

    model.create_architecture()

    device = 'cpu' if args.no_cuda else 'cuda:0'
    wrapper = FasterRCNNWrapper(device, model)

    if args.tensorboard:
        writer = SummaryWriter()

    if args.pretrained_weights:
        print("Loading weights ", args.pretrained_weights)
        state_dict = torch.load(args.pretrained_weights,
                                map_location=torch.device(device))

        if 'model' in state_dict:
            state_dict = state_dict['model']

        model.load_state_dict(state_dict)
    else:
        wrapper.train(args.batch_size, args.lr, args.epochs)

    pre_prune_mAP = wrapper.test()
    # pre_prune_mAP = 0.6772

    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(args.batch_size, args.lr,
                                                   args.epochs)

            print("Loading best weights from epoch at mAP ", curr_mAP)
            model.load_state_dict(best_weights)

        else:
            with torch.no_grad():
                curr_mAP = wrapper.test()

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_mAP, pre_prune_mAP)

        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {curr_mAP - pre_prune_mAP}")

    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(
            model.state_dict(),
            f'{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper
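
Each example prunes by calling weight_prune, whose definition is also not included here. A rough sketch of a global magnitude-pruning version matching these call sites (prune_perc in 0-100, one binary mask per weight tensor) follows; the real helper, including the layerwise_thresh option used in Example #3, may well differ:

import torch

def weight_prune(model, prune_perc):
    # Sketch only: zero the smallest prune_perc% of weights globally by
    # absolute magnitude; biases and other 1-D parameters are skipped.
    weights = [p for p in model.parameters() if p.dim() > 1]
    all_w = torch.cat([p.data.abs().flatten() for p in weights])
    cutoff = torch.quantile(all_w, prune_perc / 100.0)
    return [(p.data.abs() > cutoff).float() for p in weights]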
Example #3
def classifier_config(config, args):
    model = config['model']()
    device = 'cpu' if args.no_cuda else 'cuda:0'

    if args.tensorboard:
        writer = SummaryWriter()

    train_data = config['dataset']('./data',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose(
                                       config['transforms']))

    test_data = config['dataset']('./data',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose(
                                      config['transforms']))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)
    optimizer = config['optimizer'](model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)

    wrapper = Classifier(model, device, train_loader, test_loader)

    if args.pretrained_weights:
        print("Loading pretrained weights..")
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(args.log_interval, optimizer, args.epochs,
                      config['loss_fn'])

    pre_prune_accuracy = wrapper.test(config['loss_fn'])
    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_accuracy = pre_prune_accuracy

    if args.tensorboard:
        writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_accuracy,
                                          pre_prune_accuracy)
    while not thresh_reached:
        print(f"Testing at prune percentage {prune_perc}..")
        curr_accuracy = wrapper.test(config["loss_fn"])

        prune_iter += 1
        prune_perc += 5.
        # masks = weight_prune(model, prune_perc)
        masks = weight_prune(model, prune_perc, layerwise_thresh=True)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_accuracy, best_weights = wrapper.train(
                args.log_interval, optimizer, args.epochs, config['loss_fn'])

            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)
        else:
            with torch.no_grad():
                curr_accuracy = wrapper.test(config['loss_fn'])

        if args.tensorboard:
            writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_accuracy,
                                                 pre_prune_accuracy)

        print(f"Accuracy achieved: {curr_accuracy}")
        print(f"Change in accuracy: {diff}")

    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(
            model.state_dict(),
            f'./models/{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if args.tensorboard:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning accuracy: {pre_prune_accuracy}")
    print(f"Post-pruning accuracy: {curr_accuracy}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper
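
A hypothetical way to wire classifier_config up, mirroring every config key and args attribute it reads; LeNet here is a placeholder for any mask-aware nn.Module class:

import argparse
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

config = {
    'name': 'lenet-mnist',          # used in the saved-model filename
    'model': LeNet,                 # placeholder: any model with set_mask()
    'dataset': datasets.MNIST,
    'transforms': [transforms.ToTensor()],
    'optimizer': optim.SGD,
    'loss_fn': F.cross_entropy,
}
args = argparse.Namespace(
    no_cuda=False, tensorboard=False, batch_size=64, test_batch_size=1000,
    lr=0.01, momentum=0.9, epochs=10, log_interval=100,
    pretrained_weights=None, start_at_prune_rate=None, prune_threshold=0.02,
    no_retrain=False, save_model=False)

wrapper = classifier_config(config, args)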
def main():
    global args, best_prec1
    args = parser.parse_args()
    pruning = False
    chkpoint = False

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        # model = models.__dict__[args.arch](pretrained=True)
        model = alexnet(pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        # model = models.__dict__[args.arch]()
        model = alexnet(pretrained=False)

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            params = {
                k: v
                for k, v in checkpoint['state_dict'].items() if 'mask' not in k
            }
            mask_params = {
                k: v
                for k, v in checkpoint['state_dict'].items() if 'mask' in k
            }
            args.start_epoch = checkpoint['epoch']
            # saved_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(params)
            model.set_masks(list(mask_params.values()))
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            prune_rate(model)
            chkpoint = True
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.prune > 0 and not chkpoint:
        # prune
        print("=> pruning...")
        masks = weight_prune(model, args.prune)
        model.set_masks(masks)
        pruning = True

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'ilsvrc12_train_lmdb_224_pytorch')
    valdir = os.path.join(args.data, 'ilsvrc12_val_lmdb_224_pytorch')
    # traindir = os.path.join(args.data, 'ILSVRC2012_img_train')
    # valdir = os.path.join(args.data, 'ILSVRC2012_img_val_sorted')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                  std=[0.229, 0.224, 0.225])

    # train_dataset = datasets.ImageFolder(
    #     traindir,
    #     transforms.Compose([
    #         transforms.RandomResizedCrop(224),
    #         transforms.RandomHorizontalFlip(),
    #         transforms.ToTensor(),
    #         normalize,
    #     ]))

    if args.distributed:
        # train_dataset is only created by the commented-out ImageFolder
        # block above; restore that block before enabling --distributed,
        # or the DistributedSampler below will raise a NameError.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=args.batch_size,
    #     shuffle=(train_sampler is None),
    #     num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    train_loader = Loader('train',
                          traindir,
                          batch_size=args.batch_size,
                          num_workers=args.workers,
                          cuda=True)
    val_loader = Loader('val',
                        valdir,
                        batch_size=args.batch_size,
                        num_workers=args.workers,
                        cuda=True)

    # val_loader = torch.utils.data.DataLoader(
    #     datasets.ImageFolder(valdir, transforms.Compose([
    #         transforms.Resize(256),
    #         transforms.CenterCrop(224),
    #         transforms.ToTensor(),
    #         normalize,
    #     ])),
    #     batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return
    if pruning and not chkpoint:
        # Prune weights validation
        print("--- {}% parameters pruned ---".format(args.prune))
        validate(val_loader, model, criterion)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'iter': 0,
            },
            is_best,
            path=args.logfolder)

    print("--- After retraining ---")
    prune_rate(model)
    torch.save(model.state_dict(),
               os.path.join(args.logfolder, 'alexnet_pruned.pkl'))
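
main() calls adjust_learning_rate without defining it here. The loop closely follows the official PyTorch ImageNet example, whose schedule decays the initial learning rate tenfold every 30 epochs; a sketch under that assumption:

def adjust_learning_rate(optimizer, epoch):
    # Assumed step schedule: decay the initial lr by 10x every 30 epochs.
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr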
Example #5
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              download=True,
                              transform=transforms.ToTensor())
loader_test = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=param['test_batch_size'],
                                          shuffle=True)

# Load the pretrained model
net = MLP()
net.load_state_dict(torch.load('models/mlp_pretrained.pkl'))
if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
print("--- Pretrained network loaded ---")
test(net, loader_test)

# prune the weights
masks = weight_prune(net, param['pruning_perc'])
net.set_masks(masks)
print("--- {}% parameters pruned ---".format(param['pruning_perc']))
test(net, loader_test)

# Retraining
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(net.parameters(),
                                lr=param['learning_rate'],
                                weight_decay=param['weight_decay'])

train(net, criterion, optimizer, param, loader_train)

# Check accuracy and nonzeros weights in each layer
print("--- After retraining ---")
test(net, loader_test)
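
net.set_masks(masks) implies the MLP is built from mask-aware layers. A minimal sketch of such a layer (the actual implementation behind this example may differ) that keeps pruned connections at zero through retraining:

import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    # Sketch of a mask-aware layer: the binary mask is multiplied into
    # the weight on every forward pass, so weights pruned to zero stay
    # zero even while the optimizer keeps updating the raw tensor.
    def set_mask(self, mask):
        self.register_buffer('mask', mask.to(self.weight.device))

    def forward(self, x):
        w = self.weight * self.mask if hasattr(self, 'mask') else self.weight
        return F.linear(x, w, self.bias)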
Example #6
        enc = encoder(ip)
        enc = enc + torch.randn_like(enc, device=device) / scal
        op = decoder(enc)

        errs[i] = error_rate(op, labels)

    plt.semilogy(xx, errs + 1 / 10**hp.e_prec, label='All weights')

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=hp.lr)

print("--- Pretrained network loaded ---")
test()

# prune the weights
masks = weight_prune(net, hp.pp)
i = 0
for part in net:  # part in [encoder,decoder]
    for p in part[::2]:  # conveniently skips biases
        p.set_mask(masks[i])
        i += 1
print("--- {}% parameters pruned ---".format(hp.pp))
test()

if hp.plot:
    for i, snr in enumerate(snrs):
        print(i)
        scal = np.sqrt(snr * 2 * hp.k / hp.n)

        labels, ip = generate_input(amt=10**hp.e_prec)
        enc = encoder(ip)