Example 1
def init_model(args, num_classes):
    """Initialise model"""
    architecture = getattr(models, args.model)

    if args.curve is None:
        model = architecture.base(num_classes=num_classes,
                                  **architecture.kwargs)
    else:
        curve = getattr(curves, args.curve)
        model = curves.CurveNet(
            num_classes,
            curve,
            architecture.curve,
            args.num_bends,
            args.fix_start,
            args.fix_end,
            architecture_kwargs=architecture.kwargs,
        )
        base_model = None
        if args.resume is None:
            for path, k in [(args.init_start, 0),
                            (args.init_end, args.num_bends - 1)]:
                if path is not None:
                    if base_model is None:
                        base_model = architecture.base(num_classes=num_classes,
                                                       **architecture.kwargs)
                    checkpoint = torch.load(path)
                    print("Loading %s as point #%d" % (path, k))
                    base_model.load_state_dict(checkpoint["model_state"])
                    model.import_base_parameters(base_model, k)
            if args.init_linear:
                print("Linear initialization.")
                model.init_linear()
    model.cuda()
    return model
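For illustration, a minimal sketch of how init_model above might be driven; the SimpleNamespace stand-in for parsed command-line arguments and the 'VGG16' model name are assumptions, not taken from the example itself.

from types import SimpleNamespace

# Hypothetical argument set; only the attributes that init_model reads are filled in.
args = SimpleNamespace(model='VGG16', curve=None, num_bends=3,
                       fix_start=True, fix_end=True,
                       init_start=None, init_end=None,
                       init_linear=True, resume=None)
model = init_model(args, num_classes=10)  # plain base network, moved to the GPU
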
Example 2
def init_model(args, num_classes):
    """Initialise model"""
    architecture = getattr(models, args.model)
    curve = getattr(curves, args.curve)
    model = curves.CurveNet(
        num_classes,
        curve,
        architecture.curve,
        args.num_bends,
        architecture_kwargs=architecture.kwargs,
    )
    model.cuda()
    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint["model_state"])
    return model
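This variant always builds a CurveNet and restores it from args.ckpt. Below is a hedged sketch of querying a single point on the restored curve; it assumes the forward pass accepts the coefficient tensor t (as the t= keyword in the evaluation calls of the later examples suggests) and a CIFAR-sized input.

import torch

t = torch.FloatTensor([0.5]).cuda()   # midpoint of the curve
x = torch.randn(1, 3, 32, 32).cuda()  # dummy CIFAR-shaped batch
with torch.no_grad():
    logits = model(x, t)              # assumption: forward(input, t)
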
Example 3
loaders, num_classes = data.loaders(args.dataset,
                                    args.data_path,
                                    args.batch_size,
                                    args.num_workers,
                                    args.transform,
                                    args.use_test,
                                    shuffle_train=False)

architecture = getattr(models, args.model)
curve = getattr(curves, args.curve)

curve_model = curves.CurveNet(
    num_classes,
    curve,
    architecture.curve,
    args.num_bends,
    architecture_kwargs=architecture.kwargs,
)
curve_model.cuda()

checkpoint = torch.load(args.ckpt)
curve_model.load_state_dict(checkpoint['model_state'])

criterion = F.cross_entropy
regularizer = utils.l2_regularizer(args.wd)


def get_xy(point, origin, vector_x, vector_y):
    # Project a point onto the plane spanned by vector_x and vector_y.
    return np.array([np.dot(point - origin, vector_x),
                     np.dot(point - origin, vector_y)])
Example 4
def evaluate_curve(dir='/tmp/curve/',
                   ckpt=None,
                   num_points=61,
                   dataset='CIFAR10',
                   use_test=True,
                   transform='VGG',
                   data_path=None,
                   batch_size=128,
                   num_workers=4,
                   model_type=None,
                   curve_type=None,
                   num_bends=3,
                   wd=1e-4):
    args = EvalCurveArgSet(dir=dir,
                           ckpt=ckpt,
                           num_points=num_points,
                           dataset=dataset,
                           use_test=use_test,
                           transform=transform,
                           data_path=data_path,
                           batch_size=batch_size,
                           num_workers=num_workers,
                           model=model_type,
                           curve=curve_type,
                           num_bends=num_bends,
                           wd=wd)

    # The original command-line interface (an argparse.ArgumentParser with
    # --dir, --num_points, --dataset, --use_test, --transform, --data_path,
    # --batch_size, --num_workers, --model, --curve, --num_bends, --ckpt and
    # --wd flags, kept here as commented-out code) is superseded by the
    # EvalCurveArgSet built above.

    os.makedirs(args.dir, exist_ok=True)

    torch.backends.cudnn.benchmark = True

    loaders, num_classes = data.loaders(args.dataset,
                                        args.data_path,
                                        args.batch_size,
                                        args.num_workers,
                                        args.transform,
                                        args.use_test,
                                        shuffle_train=False)

    architecture = getattr(models, args.model)
    curve = getattr(curves, args.curve)
    model = curves.CurveNet(
        num_classes,
        curve,
        architecture.curve,
        args.num_bends,
        architecture_kwargs=architecture.kwargs,
    )
    model.cuda()
    checkpoint = torch.load(args.ckpt)
    model.load_state_dict(checkpoint['model_state'])

    criterion = F.cross_entropy
    regularizer = curves.l2_regularizer(args.wd)

    T = args.num_points
    ts = np.linspace(0.0, 1.0, T)
    tr_loss = np.zeros(T)
    tr_nll = np.zeros(T)
    tr_acc = np.zeros(T)
    te_loss = np.zeros(T)
    te_nll = np.zeros(T)
    te_acc = np.zeros(T)
    tr_err = np.zeros(T)
    te_err = np.zeros(T)
    dl = np.zeros(T)

    previous_weights = None

    columns = [
        't', 'Train loss', 'Train nll', 'Train error (%)', 'Test nll',
        'Test error (%)', 'Distance'
    ]

    t = torch.FloatTensor([0.0]).cuda()
    for i, t_value in enumerate(ts):
        t.data.fill_(t_value)
        weights = model.weights(t)
        if previous_weights is not None:
            dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
        previous_weights = weights.copy()

        utils.update_bn(loaders['train'], model, t=t)
        tr_res = utils.test(loaders['train'],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        te_res = utils.test(loaders['test'],
                            model,
                            criterion,
                            regularizer,
                            t=t)
        tr_loss[i] = tr_res['loss']
        tr_nll[i] = tr_res['nll']
        tr_acc[i] = tr_res['accuracy']
        tr_err[i] = 100.0 - tr_acc[i]
        te_loss[i] = te_res['loss']
        te_nll[i] = te_res['nll']
        te_acc[i] = te_res['accuracy']
        te_err[i] = 100.0 - te_acc[i]

        values = [
            t_value, tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i],
            dl[i]
        ]
        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='10.4f')
        if i % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    def stats(values, dl):
        # Summary statistics along the curve; the last value is the
        # arc-length-weighted trapezoidal average of `values`.
        vmin = np.min(values)
        vmax = np.max(values)
        avg = np.mean(values)
        integral = np.sum(0.5 *
                          (values[:-1] + values[1:]) * dl[1:]) / np.sum(dl[1:])
        return vmin, vmax, avg, integral

    tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
    tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
    tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)

    te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
    te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
    te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)

    print('Length: %.2f' % np.sum(dl))
    print(
        tabulate.tabulate([
            [
                'train loss', tr_loss[0], tr_loss[-1], tr_loss_min,
                tr_loss_max, tr_loss_avg, tr_loss_int
            ],
            [
                'train error (%)', tr_err[0], tr_err[-1], tr_err_min,
                tr_err_max, tr_err_avg, tr_err_int
            ],
            [
                'test nll', te_nll[0], te_nll[-1], te_nll_min, te_nll_max,
                te_nll_avg, te_nll_int
            ],
            [
                'test error (%)', te_err[0], te_err[-1], te_err_min,
                te_err_max, te_err_avg, te_err_int
            ],
        ], ['', 'start', 'end', 'min', 'max', 'avg', 'int'],
                          tablefmt='simple',
                          floatfmt='10.4f'))

    np.savez(
        os.path.join(args.dir, 'curve.npz'),
        ts=ts,
        dl=dl,
        tr_loss=tr_loss,
        tr_loss_min=tr_loss_min,
        tr_loss_max=tr_loss_max,
        tr_loss_avg=tr_loss_avg,
        tr_loss_int=tr_loss_int,
        tr_nll=tr_nll,
        tr_nll_min=tr_nll_min,
        tr_nll_max=tr_nll_max,
        tr_nll_avg=tr_nll_avg,
        tr_nll_int=tr_nll_int,
        tr_acc=tr_acc,
        tr_err=tr_err,
        tr_err_min=tr_err_min,
        tr_err_max=tr_err_max,
        tr_err_avg=tr_err_avg,
        tr_err_int=tr_err_int,
        te_loss=te_loss,
        te_loss_min=te_loss_min,
        te_loss_max=te_loss_max,
        te_loss_avg=te_loss_avg,
        te_loss_int=te_loss_int,
        te_nll=te_nll,
        te_nll_min=te_nll_min,
        te_nll_max=te_nll_max,
        te_nll_avg=te_nll_avg,
        te_nll_int=te_nll_int,
        te_acc=te_acc,
        te_err=te_err,
        te_err_min=te_err_min,
        te_err_max=te_err_max,
        te_err_avg=te_err_avg,
        te_err_int=te_err_int,
    )
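A hedged example of driving evaluate_curve and reading back the arrays it saves; every path and the model/curve names below are placeholders rather than values from the snippet.

evaluate_curve(dir='/tmp/curve/',
               ckpt='curve/checkpoint-200.pt',  # placeholder path
               model_type='VGG16',              # placeholder model name
               curve_type='Bezier',             # placeholder curve type
               data_path='~/data')              # placeholder dataset root

import numpy as np

results = np.load('/tmp/curve/curve.npz')
print(results['ts'])      # the num_points (default 61) sampled values of t
print(results['te_err'])  # test error (%) at each sampled point
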
Example 5
def train_model(dir='/tmp/curve/',
                dataset='CIFAR10',
                use_test=True,
                transform='VGG',
                data_path=None,
                batch_size=128,
                num_workers=4,
                model_type=None,
                curve_type=None,
                num_bends=3,
                init_start=None,
                fix_start=True,
                init_end=None,
                fix_end=True,
                init_linear=True,
                resume=None,
                epochs=200,
                save_freq=50,
                lr=.01,
                momentum=.9,
                wd=1e-4,
                seed=1):
    args = TrainArgSet(dir=dir,
                       dataset=dataset,
                       use_test=use_test,
                       transform=transform,
                       data_path=data_path,
                       batch_size=batch_size,
                       num_workers=num_workers,
                       model=model_type,
                       curve=curve_type,
                       num_bends=num_bends,
                       init_start=init_start,
                       fix_start=fix_start,
                       init_end=init_end,
                       fix_end=fix_end,
                       init_linear=init_linear,
                       resume=resume,
                       epochs=epochs,
                       save_freq=save_freq,
                       lr=lr,
                       momentum=momentum,
                       wd=wd,
                       seed=seed)

    os.makedirs(args.dir, exist_ok=True)
    with open(os.path.join(args.dir, 'command.sh'), 'w') as f:
        f.write(' '.join(sys.argv))
        f.write('\n')

    torch.backends.cudnn.benchmark = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    loaders, num_classes = data.loaders(args.dataset, args.data_path,
                                        args.batch_size, args.num_workers,
                                        args.transform, args.use_test)

    architecture = getattr(models, args.model)

    if args.curve is None:
        model = architecture.base(num_classes=num_classes,
                                  **architecture.kwargs)
    else:
        curve = getattr(curves, args.curve)
        model = curves.CurveNet(
            num_classes,
            curve,
            architecture.curve,
            args.num_bends,
            args.fix_start,
            args.fix_end,
            architecture_kwargs=architecture.kwargs,
        )
        base_model = None
        if args.resume is None:
            for path, k in [(args.init_start, 0),
                            (args.init_end, args.num_bends - 1)]:
                if path is not None:
                    if base_model is None:
                        base_model = architecture.base(num_classes=num_classes,
                                                       **architecture.kwargs)
                    checkpoint = torch.load(path)
                    print('Loading %s as point #%d' % (path, k))
                    base_model.load_state_dict(checkpoint['model_state'])
                    model.import_base_parameters(base_model, k)
            if args.init_linear:
                print('Linear initialization.')
                model.init_linear()
    model.cuda()

    def learning_rate_schedule(base_lr, epoch, total_epochs):
        # Piecewise-linear schedule: constant for the first half of training,
        # linear decay to 1% of base_lr by the 90% mark, then a final decay
        # toward zero over the last 10% of epochs.
        alpha = epoch / total_epochs
        if alpha <= 0.5:
            factor = 1.0
        elif alpha <= 0.9:
            factor = 1.0 - (alpha - 0.5) / 0.4 * 0.99
        else:
            factor = 0.01 * (1 - ((alpha - 0.9) / 0.1))
        return factor * base_lr

    criterion = F.cross_entropy
    regularizer = None if args.curve is None else curves.l2_regularizer(
        args.wd)
    optimizer = torch.optim.SGD(
        filter(lambda param: param.requires_grad, model.parameters()),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.wd if args.curve is None else 0.0)

    start_epoch = 1
    if args.resume is not None:
        print('Resume training from %s' % args.resume)
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_nll', 'te_acc', 'time']

    utils.save_checkpoint(args.dir,
                          start_epoch - 1,
                          model_state=model.state_dict(),
                          optimizer_state=optimizer.state_dict())

    has_bn = utils.check_bn(model)
    test_res = {'loss': None, 'accuracy': None, 'nll': None}
    for epoch in range(start_epoch, args.epochs + 1):

        time_ep = time.time()

        lr = learning_rate_schedule(args.lr, epoch, args.epochs)
        utils.adjust_learning_rate(optimizer, lr)

        train_res = utils.train(loaders['train'], model, optimizer, criterion,
                                regularizer)
        if args.curve is None or not has_bn:
            test_res = utils.test(loaders['test'], model, criterion,
                                  regularizer)

        if epoch % args.save_freq == 0:
            utils.save_checkpoint(args.dir,
                                  epoch,
                                  model_state=model.state_dict(),
                                  optimizer_state=optimizer.state_dict())

        time_ep = time.time() - time_ep
        values = [
            epoch, lr, train_res['loss'], train_res['accuracy'],
            test_res['nll'], test_res['accuracy'], time_ep
        ]

        table = tabulate.tabulate([values],
                                  columns,
                                  tablefmt='simple',
                                  floatfmt='9.4f')
        if epoch % 40 == 1 or epoch == start_epoch:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)

    if args.epochs % args.save_freq != 0:
        utils.save_checkpoint(args.dir,
                              args.epochs,
                              model_state=model.state_dict(),
                              optimizer_state=optimizer.state_dict())
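A hedged sketch of a typical call: connecting two independently trained endpoint checkpoints with a three-bend curve whose endpoints stay fixed. The checkpoint paths and the model/curve names are placeholders.

train_model(dir='curves/vgg16-bezier/',           # placeholder directory
            model_type='VGG16',                   # placeholder model name
            curve_type='Bezier',                  # placeholder curve type
            num_bends=3,
            init_start='run1/checkpoint-200.pt',  # placeholder endpoint
            init_end='run2/checkpoint-200.pt',    # placeholder endpoint
            fix_start=True, fix_end=True,
            epochs=200, lr=0.01, wd=1e-4)
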
Example 6
parser.add_argument('--ckpt',  # flag name reconstructed; the snippet is truncated above this line
                    type=str,
                    default='Para128-256/checkpoint-180.pt',
                    metavar='CKPT',
                    help='checkpoint to eval (default: None)')

args = parser.parse_args()

os.makedirs(args.dir, exist_ok=True)

torch.backends.cudnn.benchmark = True

architecture = getattr(models, args.model)
curve = getattr(curves, args.curve)
model = curves.CurveNet(
    10,
    curve,
    architecture.curve,
    args.num_bends,
    architecture_kwargs=architecture.kwargs,
)

model.cuda()
checkpoint = torch.load(args.ckpt)
model.load_state_dict(checkpoint['model_state'])

spmodel = architecture.base(num_classes=10, **architecture.kwargs)

parameters = list(model.net.parameters())
sppara = list(spmodel.parameters())
#for i in range(0, len(sppara)):
#    ttt= i*3
#    weights = parameters[ttt:ttt + model.num_bends]
os.makedirs(args.dir, exist_ok=True)

d = args.dir
init_start = args.init_start  #'curves/curve50/checkpoint-100.pt'
init_middle = args.init_middle  #'curves/curve51/checkpoint-100.pt'
init_end = args.init_end  #'curves/curve52/checkpoint-100.pt'
num_classes = 10

architecture = getattr(models, args.model)
curve = getattr(curves, args.curve)

model = curves.CurveNet(
    10,
    curve,
    architecture.curve,
    3,
    True,
    True,
    architecture_kwargs=architecture.kwargs,
)

base_model = None
for path, k in [(init_start, 0), (init_middle, 1), (init_end, 2)]:
    if path is not None:
        if base_model is None:
            base_model = architecture.base(num_classes=num_classes,
                                           **architecture.kwargs)
        checkpoint = torch.load(path)
        print('Loading %s as point #%d' % (path, k))
        base_model.load_state_dict(checkpoint['model_state'])
        model.import_base_parameters(base_model, k)
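
The snippet stops after importing the three checkpoints as the curve's control points. A hedged continuation: copying one control point's weights back into the plain spmodel built above, assuming CurveNet exposes an export_base_parameters counterpart to the import_base_parameters call used here.

k = 1                                      # the middle control point
model.export_base_parameters(spmodel, k)   # assumption: method exists
spmodel.cuda()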