Example #1
def __init__(self, bits_A, bits_E, A_mode="nearest", E_mode="nearest"):
    super(WAGEQuantizer, self).__init__()
    # Fixed-point formats for activations and errors; a bit width of -1
    # disables quantization for that signal.
    self.activate_number = (FixedPoint(wl=bits_A, fl=bits_A - 1,
                                       clamp=True, symmetric=True)
                            if bits_A != -1 else None)
    self.error_number = (FixedPoint(wl=bits_E, fl=bits_E - 1,
                                    clamp=True, symmetric=True)
                         if bits_E != -1 else None)
    # Quantize activations in the forward pass and errors in the backward
    # pass; gradients clamped to zero stay zero, and the WAGE shift is
    # applied as a backward hook.
    self.quantizer = quantizer(forward_number=self.activate_number,
                               forward_rounding=A_mode,
                               backward_number=self.error_number,
                               backward_rounding=E_mode,
                               clamping_grad_zero=True,
                               backward_hooks=[shift])
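
A minimal usage sketch for the constructor above, assuming the enclosing class is a torch.nn.Module named WAGEQuantizer, that FixedPoint and quantizer come from qtorch (from qtorch import FixedPoint; from qtorch.quant import quantizer), and that shift is a backward hook defined in the surrounding module:

import torch

q = WAGEQuantizer(bits_A=8, bits_E=8)      # 8-bit activations and errors
x = torch.randn(4, 16, requires_grad=True)
x_q = q.quantizer(x)  # fixed-point rounding forward, error quantization backward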
Example #2
loaders = get_data(args.dataset, args.data_path, args.batch_size,
                   args.val_ratio, args.num_workers)
if args.dataset == "CIFAR10":
    num_classes = 10
elif args.dataset == "IMAGENET12":
    num_classes = 1000
else:
    raise ValueError("unsupported dataset: {}".format(args.dataset))

# Build one low-precision quantizer per number category in num_types
quantizers = {}
for num in num_types:
    num_rounding = getattr(args, "{}_rounding".format(num))
    num_man = getattr(args, "{}_man".format(num))
    num_exp = getattr(args, "{}_exp".format(num))
    number = FloatingPoint(exp=num_exp, man=num_man)
    print("{}: {} rounding, {}".format(num, num_rounding, number))
    quantizers[num] = quantizer(forward_number=number,
                                forward_rounding=num_rounding)
# Build model
print("Model: {}".format(args.model))
model_cfg = getattr(models, args.model)
if "LP" in args.model:
    activate_number = FloatingPoint(exp=args.activate_exp,
                                    man=args.activate_man)
    error_number = FloatingPoint(exp=args.error_exp, man=args.error_man)
    print("activation: {}, {}".format(args.activate_rounding, activate_number))
    print("error: {}, {}".format(args.error_rounding, error_number))
    make_quant = lambda: Quantizer(activate_number, error_number,
                                   args.activate_rounding,
                                   args.error_rounding)
    model_cfg.kwargs.update({"quant": make_quant})

model = model_cfg.base(*model_cfg.args,
                       num_classes=num_classes,
                       **model_cfg.kwargs)
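
A hedged sketch of how a quantizers dict like the one above is commonly consumed, assuming qtorch's OptimLP wrapper and that num_types includes the keys "weight", "grad", and "momentum" (neither is shown in this snippet):

from torch.optim import SGD
from qtorch.optim import OptimLP

optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
# OptimLP re-quantizes weights, gradients, and momentum buffers with the
# configured low-precision formats around each optimizer step.
optimizer = OptimLP(optimizer,
                    weight_quant=quantizers["weight"],
                    grad_quant=quantizers["grad"],
                    momentum_quant=quantizers["momentum"])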
Example #3
def main():
    DATA_DIR = './data'
    dataset = cifar10(root=DATA_DIR)
    timer = Timer()
    print('Preprocessing training data')
    transforms = [
        partial(normalise,
                mean=np.array(cifar10_mean, dtype=np.float32),
                std=np.array(cifar10_std, dtype=np.float32)),
        partial(transpose, source='NHWC', target='NCHW'),
    ]
    train_set = list(
        zip(*preprocess(dataset['train'], [partial(pad, border=4)] +
                        transforms).values()))
    print(f'Finished in {timer():.2} seconds')
    print('Preprocessing test data')
    valid_set = list(zip(*preprocess(dataset['valid'], transforms).values()))
    print(f'Finished in {timer():.2} seconds')

    epochs = args.epochs
    # Triangular LR schedule: linear warm-up from 0 to 0.4 over the first
    # five epochs, then linear decay back to 0 by the final epoch.
    lr_schedule = PiecewiseLinear([0, 5, epochs], [0, 0.4, 0])
    batch_size = 512
    train_transforms = [Crop(32, 32), FlipLR(), Cutout(8, 8)]
    N_runs = 1

    train_batches = DataLoader(Transform(train_set, train_transforms),
                               batch_size,
                               shuffle=True,
                               set_random_choices=True,
                               drop_last=True)
    valid_batches = DataLoader(valid_set,
                               batch_size,
                               shuffle=False,
                               drop_last=False)
    # cifar10-fast-style training sums the per-example loss over the batch,
    # so the step size is divided by batch_size to keep the update scale
    # independent of the batch size.
    lr = lambda step: lr_schedule(step / len(train_batches)) / batch_size

    ##### Fetch the LP attributes ########
    if args.LP:
        quantizers = {}
        for num in num_types:
            num_rounding = getattr(args, "{}_rounding".format(num))
            num_man = getattr(args, "{}_man".format(num))
            num_exp = getattr(args, "{}_exp".format(num))
            number = FloatingPoint(exp=num_exp, man=num_man)
            logger.info("{}: {} rounding, {}".format(num, num_rounding,
                                                     number))
            quantizers[num] = quantizer(forward_number=number,
                                        forward_rounding=num_rounding)
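        # (See the OptimLP sketch after Example #2 for one way such
        # quantizers are typically attached to the optimizer.)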

    summaries = []
    for i in range(N_runs):
        print(f'Starting Run {i} at {localtime()}')
        model = Network(
            net(gamma=args.TD_gamma,
                alpha=args.TD_alpha,
                block_size=args.block_size)).to(device)

        Hooks_input = utils.add_input_record_Hook(model)
        opts = [
            SGD(
                trainable_params(model).values(), {
                    'lr': lr,
                    'weight_decay': Const(5e-4 * batch_size),
                    'momentum': Const(0.9)
                })
        ]
        logs, state = Table(), {MODEL: model, LOSS: x_ent_loss, OPTS: opts}

        activation_sparsity = 0.0
        asparse = []
        wsparse = []
        for epoch in range(epochs):
            td_gamma, td_alpha = update_gamma_alpha(epoch, model)

            weight_sparsity = utils.get_weight_sparsity(model)
            # Note: activation_sparsity is refreshed only after the epoch
            # runs, so the value logged below lags by one epoch (0.0 for
            # the first).
            logs.append(
                union({'epoch': epoch + 1}, {'lr': lr_schedule(epoch + 1)},
                      {'gamma': td_gamma}, {'alpha': td_alpha},
                      {'wspar': weight_sparsity},
                      {'aspar': round(activation_sparsity, 4)},
                      train_epoch(state, Timer(torch.cuda.synchronize),
                                  train_batches, valid_batches)))
            activation_sparsity = utils.get_activation_sparsity(
                Hooks_input).item()

            asparse.append(activation_sparsity)
            wsparse.append(weight_sparsity)

        print(
            f'Avg weight sparsity: {np.mean(wsparse)} | Avg act sparsity: {np.mean(asparse)}'
        )

    print(logs.df().query(f'epoch=={epochs}')[['train_acc', 'valid_acc']]
          .describe())

    if args.save_file is not None:
        os.makedirs('checkpoint', exist_ok=True)  # make sure the target dir exists
        torch.save({
            'state_dict': model.state_dict(),
            'args': args
        }, os.path.join('checkpoint', args.save_file))
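
A minimal sketch of restoring the checkpoint written above, assuming the same model definition and the 'checkpoint' directory used in main():

ckpt = torch.load(os.path.join('checkpoint', args.save_file))
model.load_state_dict(ckpt['state_dict'])
saved_args = ckpt['args']  # the argparse namespace saved with the weights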