예제 #1
0
def classifier_config(config, args):
    """Load a pretrained classifier, k-means quantize it, and report accuracy.

    The model is evaluated on the test split before and after
    ``quantize_k_means`` is applied, and the wall-clock time spent
    quantizing is printed.

    Args:
        config: Model configuration dict. Reads the ``'model'``,
            ``'dataset'``, ``'transforms'`` and ``'loss_fn'`` keys.
        args: Parsed command-line arguments. Reads ``no_cuda`` and
            ``pretrained_weights``.

    Returns:
        The ``Classifier`` wrapper around the quantized model.
    """
    model = config['model']()

    # NOTE(review): this entry point pins the second GPU ('cuda:1') while the
    # pruning entry points use 'cuda:0' — confirm this is intended for the host.
    device = 'cuda:1' if not args.no_cuda else 'cpu'

    # Only the test split is needed: this function never trains. (The training
    # split used to be constructed and downloaded here without being used.)
    test_data = config['dataset']('./data',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose(
                                      config['transforms']))

    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=8,
        shuffle=True,
        num_workers=1,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=not args.no_cuda)

    # No training loader is passed (None): the wrapper is only used for testing.
    wrapper = Classifier(model, device, None, test_loader)

    model.load_state_dict(
        torch.load(args.pretrained_weights, map_location=device))

    # Baseline accuracy before quantization.
    wrapper.test(config["loss_fn"])

    print("Started quantizing")
    start_time = datetime.datetime.now()

    quantize_k_means(model, show_figures=True)

    # Presumably reports the percentage of zero weights after quantization.
    prune_rate(model)

    end_time = datetime.datetime.now()
    print(f"Finished quantizing. Time taken: {end_time - start_time}")

    # Accuracy after quantization.
    wrapper.test(config["loss_fn"])

    return wrapper
예제 #2
0
def yolo_config(config, args):
    """Load pretrained YOLOv3 weights, k-means quantize them, and evaluate.

    Prints the percentage of zero weights before and after quantization and
    the post-quantization mAP on the test split, and optionally saves the
    quantized weights under ``./models/``.

    Args:
        config: YOLOv3 configuration dict (reads ``'name'``, ``'model'``,
            ``'config_path'``, ``'datasets'`` and ``'image_size'``). If
            ``None``, the ``'YOLOv3'`` entry of the module-level
            ``configurations`` list is used instead.
        args: Parsed command-line arguments. Reads ``no_cuda``,
            ``pretrained_weights`` and ``save_model``.

    Returns:
        The ``YoloWrapper`` around the quantized model.
    """
    import os

    # The caller's config used to be unconditionally replaced by a global
    # lookup; honour it and keep the lookup only as a fallback.
    if config is None:
        config = [x for x in configurations if x['name'] == 'YOLOv3'][0]
    model = config['model'](config['config_path'])

    # NOTE(review): pins the second GPU ('cuda:1'); the pruning variant uses
    # 'cuda:0' — confirm this is intended for the target host.
    device = 'cuda:1' if not args.no_cuda else 'cpu'

    wrapper = YoloWrapper(device, model)

    print("Loading dataloaders..")
    val_dataloader = LoadImagesAndLabels(config['datasets']['test'],
                                         batch_size=32,
                                         img_size=config['image_size'])

    model.to(device)

    print("Loading pretrained weights..")
    model.load_state_dict(
        torch.load(args.pretrained_weights, map_location=device))

    print("Pre-quantized percentage of zeros..")
    prune_rate(model)

    print("Quantizing..")
    quantize_k_means(model)

    with torch.no_grad():
        mAP, _, _ = wrapper.test(val_dataloader,
                                 img_size=config['image_size'],
                                 batch_size=32)
        print("Accuracy: ", mAP)

    print("Post-quantize percentage of zeros..")
    prune_rate(model)

    if args.save_model:
        # torch.save does not create parent directories.
        os.makedirs('./models', exist_ok=True)
        torch.save(
            model.state_dict(),
            f'./models/{config["name"]}-quantized-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    return wrapper
예제 #3
0
def yolo_config(config, args):
    """Iteratively magnitude-prune a YOLOv3 model until a mAP threshold is hit.

    Either loads pretrained weights or trains from scratch (saving the
    result), records the baseline mAP, then repeatedly raises the prune
    percentage by 5 points, applies the masks from ``weight_prune``,
    optionally retrains for 3 epochs (keeping the best weights), and
    re-evaluates until ``reached_threshold`` reports that the threshold
    was reached.

    Args:
        config: YOLOv3 configuration dict (reads ``'name'``, ``'model'``,
            ``'config_path'``, ``'optimizer'``, ``'datasets'`` and
            ``'image_size'``).
        args: Parsed command-line arguments. Reads ``no_cuda``,
            ``momentum``, ``batch_size``, ``pretrained_weights``, ``epochs``,
            ``start_at_prune_rate``, ``tensorboard``, ``prune_threshold``,
            ``no_retrain`` and ``save_model``.

    Returns:
        The ``YoloWrapper`` around the pruned model.
    """
    device = 'cpu' if args.no_cuda else 'cuda:0'
    model = config['model'](config['config_path'], device=device)
    wrapper = YoloWrapper(device, model)
    lr0 = 0.001
    optimizer = config['optimizer'](filter(lambda x: x.requires_grad,
                                           model.parameters()),
                                    lr=lr0,
                                    momentum=args.momentum)
    # Only create a writer (and its run directory) when logging is requested.
    writer = SummaryWriter() if args.tensorboard else None

    print("Loading dataloaders..")
    train_dataloader = LoadImagesAndLabels(config['datasets']['train'],
                                           batch_size=args.batch_size,
                                           img_size=config['image_size'])
    val_dataloader = LoadImagesAndLabels(config['datasets']['test'],
                                         batch_size=args.batch_size,
                                         img_size=config['image_size'])

    if args.pretrained_weights:
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(train_dataloader, val_dataloader, args.epochs, optimizer,
                      lr0)
        torch.save(model.state_dict(), "YOLOv3-gate-prepruned.pt")

    with torch.no_grad():
        pre_prune_mAP, _, _ = wrapper.test(val_dataloader,
                                           img_size=config['image_size'],
                                           batch_size=args.batch_size)

    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if writer is not None:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    # NOTE(review): the loop only exits after a pruning step for which
    # reached_threshold returns True, so the model saved/returned below is the
    # one from that final step — confirm that is the intended stopping point.
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)

        print(
            f"Just pruned with prune_perc={prune_perc}, now has {prune_rate(model, verbose=False)}% zeros"
        )

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(train_dataloader,
                                                   val_dataloader, 3,
                                                   optimizer, lr0)

            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)

            print(
                f"Just finished training with prune_perc={prune_perc}, now has {prune_rate(model, verbose=False)}% zeros"
            )
        else:
            with torch.no_grad():
                curr_mAP, _, _ = wrapper.test(val_dataloader,
                                              img_size=config['image_size'],
                                              batch_size=args.batch_size)

        if writer is not None:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_mAP, pre_prune_mAP)

        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {diff}")

    # From here on prune_perc holds the measured percentage of zero weights.
    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(model.state_dict(),
                   "YOLOv3-gate-pruned-modelcompression.pt")

    if writer is not None:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)
        # Flush pending events to disk and release the event file.
        writer.close()

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper
예제 #4
0
def frcnn_config(config, args):
    """Iteratively magnitude-prune a Faster R-CNN model on the 20 VOC classes.

    Either loads pretrained weights (accepting a raw state dict or a
    checkpoint dict with a ``'model'`` entry) or trains from scratch, records
    the baseline mAP, then repeatedly raises the prune percentage by 5
    points, applies the masks from ``weight_prune``, optionally retrains
    (keeping the best weights), and re-evaluates until ``reached_threshold``
    reports that the threshold was reached.

    Args:
        config: Model configuration dict (reads ``'model'`` and ``'name'``).
        args: Parsed command-line arguments. Reads ``no_cuda``,
            ``tensorboard``, ``pretrained_weights``, ``batch_size``, ``lr``,
            ``epochs``, ``start_at_prune_rate``, ``prune_threshold``,
            ``no_retrain`` and ``save_model``.

    Returns:
        The ``FasterRCNNWrapper`` around the pruned model.
    """
    classes = (
        '__background__',  # always index 0
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor')

    model = config['model'](classes)

    model.create_architecture()

    device = 'cpu' if args.no_cuda else 'cuda:0'
    wrapper = FasterRCNNWrapper(device, model)

    # Only create a writer (and its run directory) when logging is requested.
    writer = SummaryWriter() if args.tensorboard else None

    if args.pretrained_weights:
        print("Loading weights ", args.pretrained_weights)
        # Map onto the selected device; hard-coding 'cuda:0' here broke
        # --no_cuda runs on machines without a GPU.
        state_dict = torch.load(args.pretrained_weights,
                                map_location=torch.device(device))

        # Some checkpoints wrap the weights under a 'model' key.
        if 'model' in state_dict:
            state_dict = state_dict['model']

        model.load_state_dict(state_dict)
    else:
        wrapper.train(args.batch_size, args.lr, args.epochs)

    with torch.no_grad():
        pre_prune_mAP = wrapper.test()

    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_mAP = pre_prune_mAP

    if writer is not None:
        writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_mAP,
                                          pre_prune_mAP)
    # NOTE(review): the loop only exits after a pruning step for which
    # reached_threshold returns True, so the model saved/returned below is the
    # one from that final step — confirm that is the intended stopping point.
    while not thresh_reached:
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_mAP, best_weights = wrapper.train(args.batch_size, args.lr,
                                                   args.epochs)

            print("Loading best weights from epoch at mAP ", curr_mAP)
            model.load_state_dict(best_weights)

        else:
            with torch.no_grad():
                curr_mAP = wrapper.test()

        if writer is not None:
            writer.add_scalar('prune/accuracy', curr_mAP, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, _ = reached_threshold(args.prune_threshold,
                                              curr_mAP, pre_prune_mAP)

        print(f"mAP achieved: {curr_mAP}")
        print(f"Change in mAP: {curr_mAP - pre_prune_mAP}")

    # From here on prune_perc holds the measured percentage of zero weights.
    prune_perc = prune_rate(model)

    if args.save_model:
        torch.save(
            model.state_dict(),
            f'{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if writer is not None:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)
        # Flush pending events to disk and release the event file.
        writer.close()

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning mAP: {pre_prune_mAP}")
    print(f"Post-pruning mAP: {curr_mAP}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper
예제 #5
0
def classifier_config(config, args):
    """Iteratively magnitude-prune a classifier until an accuracy threshold is hit.

    Either loads pretrained weights or trains from scratch, records the
    baseline test accuracy, then repeatedly raises the prune percentage by 5
    points, applies layer-wise masks from ``weight_prune``, optionally
    retrains (keeping the best weights), and re-evaluates until
    ``reached_threshold`` reports that the threshold was reached.

    Args:
        config: Model configuration dict (reads ``'model'``, ``'dataset'``,
            ``'transforms'``, ``'optimizer'``, ``'loss_fn'`` and ``'name'``).
        args: Parsed command-line arguments. Reads ``no_cuda``,
            ``tensorboard``, ``batch_size``, ``test_batch_size``, ``lr``,
            ``momentum``, ``pretrained_weights``, ``log_interval``,
            ``epochs``, ``start_at_prune_rate``, ``prune_threshold``,
            ``no_retrain`` and ``save_model``.

    Returns:
        The ``Classifier`` wrapper around the pruned model.
    """
    import os

    model = config['model']()
    device = 'cpu' if args.no_cuda else 'cuda:0'

    # Only create a writer (and its run directory) when logging is requested.
    writer = SummaryWriter() if args.tensorboard else None

    train_data = config['dataset']('./data',
                                   train=True,
                                   download=True,
                                   transform=transforms.Compose(
                                       config['transforms']))

    test_data = config['dataset']('./data',
                                  train=False,
                                  download=True,
                                  transform=transforms.Compose(
                                      config['transforms']))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=1)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              num_workers=1)
    optimizer = config['optimizer'](model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)

    wrapper = Classifier(model, device, train_loader, test_loader)

    if args.pretrained_weights:
        print("Loading pretrained weights..")
        model.load_state_dict(
            torch.load(args.pretrained_weights,
                       map_location=torch.device(device)))
    else:
        wrapper.train(args.log_interval, optimizer, args.epochs,
                      config['loss_fn'])

    pre_prune_accuracy = wrapper.test(config['loss_fn'])
    prune_perc = 0. if args.start_at_prune_rate is None else args.start_at_prune_rate
    prune_iter = 0
    curr_accuracy = pre_prune_accuracy

    if writer is not None:
        writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
        writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        for name, param in wrapper.model.named_parameters():
            if 'bn' not in name:
                writer.add_histogram(f'prune/preprune/{name}', param,
                                     prune_iter)

    thresh_reached, _ = reached_threshold(args.prune_threshold, curr_accuracy,
                                          pre_prune_accuracy)
    # NOTE(review): the loop only exits after a pruning step for which
    # reached_threshold returns True, so the model saved/returned below is the
    # one from that final step — confirm that is the intended stopping point.
    while not thresh_reached:
        # No evaluation here: curr_accuracy is always recomputed below before
        # it is read, so a test pass at the top of the loop was wasted work.
        prune_iter += 1
        prune_perc += 5.
        masks = weight_prune(model, prune_perc, layerwise_thresh=True)
        model.set_mask(masks)

        if not args.no_retrain:
            print(f"Retraining at prune percentage {prune_perc}..")
            curr_accuracy, best_weights = wrapper.train(
                args.log_interval, optimizer, args.epochs, config['loss_fn'])

            print("Loading best weights from training epochs..")
            model.load_state_dict(best_weights)
        else:
            with torch.no_grad():
                curr_accuracy = wrapper.test(config['loss_fn'])

        if writer is not None:
            writer.add_scalar('prune/accuracy', curr_accuracy, prune_iter)
            writer.add_scalar('prune/percentage', prune_perc, prune_iter)

        thresh_reached, diff = reached_threshold(args.prune_threshold,
                                                 curr_accuracy,
                                                 pre_prune_accuracy)

        print(f"Accuracy achieved: {curr_accuracy}")
        print(f"Change in accuracy: {diff}")

    # From here on prune_perc holds the measured percentage of zero weights.
    prune_perc = prune_rate(model)

    if args.save_model:
        # torch.save does not create parent directories.
        os.makedirs('./models', exist_ok=True)
        torch.save(
            model.state_dict(),
            f'./models/{config["name"]}-pruned-{datetime.datetime.now().strftime("%Y%m%d%H%M")}.pt'
        )

    if writer is not None:
        for name, param in wrapper.model.named_parameters():
            if 'weight' in name:
                writer.add_histogram(f'prune/postprune/{name}', param,
                                     prune_iter + 1)
        # Flush pending events to disk and release the event file.
        writer.close()

    print(f"Pruned model: {config['name']}")
    print(f"Pre-pruning accuracy: {pre_prune_accuracy}")
    print(f"Post-pruning accuracy: {curr_accuracy}")
    print(f"Percentage of zeroes: {prune_perc}")

    return wrapper