예제 #1
0
def load(args):
    """Set up a VIB-gated VGG model and its CIFAR-10 training pieces.

    Parameters whose name contains ``'gate'`` go into one Adam param group
    with lr=1e-2 (no per-group weight_decay). All other parameters go into a
    second group with lr=1e-3 and weight_decay=1e-4. The scheduler scales
    every group's learning rate by 0.1 at the (truncated) epochs
    corresponding to 50% and 80% of ``args.num_epochs``.

    Args:
        args: Experiment namespace. ``add_args`` is applied to it first;
            ``args.batch_size`` and ``args.num_epochs`` are read afterwards.

    Returns:
        Tuple ``(net, train_loader, test_loader, optimizer, scheduler)``.
    """
    # NOTE(review): presumably merges this experiment's `sub_args` into
    # `args` in place; confirm against add_args.
    add_args(args, sub_args)

    model = VGG(10)  # 10 presumably = number of CIFAR-10 classes; confirm
    model.build_gate(VIB)
    train_loader, test_loader = get_CIFAR10(args.batch_size)

    # Partition parameters by name. Each list keeps named_parameters() order.
    by_kind = {True: [], False: []}
    for param_name, param in model.named_parameters():
        by_kind['gate' in param_name].append(param)

    gate_group = {'params': by_kind[True], 'lr': 1e-2}
    base_group = {'params': by_kind[False], 'lr': 1e-3, 'weight_decay': 1e-4}
    optimizer = optim.Adam([gate_group, base_group])

    # NOTE(review): int() truncates, so a very small num_epochs can produce
    # a milestone of 0 or repeated milestones; confirm that is acceptable.
    decay_epochs = [int(frac * args.num_epochs) for frac in (.5, .8)]
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=decay_epochs, gamma=0.1)

    return model, train_loader, test_loader, optimizer, scheduler
예제 #2
0
def load(args):
    """Set up an ungated VGG baseline and its CIFAR-10 training pieces.

    All model parameters share one Adam optimizer (lr=1e-3,
    weight_decay=1e-4). The scheduler scales the learning rate by 0.1 at the
    (truncated) epochs corresponding to 50% and 80% of ``args.num_epochs``.

    Args:
        args: Experiment namespace. ``add_args`` is applied to it first;
            ``args.batch_size`` and ``args.num_epochs`` are read afterwards.

    Returns:
        Tuple ``(net, train_loader, test_loader, optimizer, scheduler)``.
    """
    # NOTE(review): presumably merges this experiment's `sub_args` into
    # `args` in place; confirm against add_args.
    add_args(args, sub_args)

    model = VGG(10)  # 10 presumably = number of CIFAR-10 classes; confirm
    train_loader, test_loader = get_CIFAR10(args.batch_size)

    optimizer = optim.Adam(model.parameters(), weight_decay=1e-4, lr=1e-3)

    decay_epochs = [int(frac * args.num_epochs) for frac in (0.5, 0.8)]
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=decay_epochs, gamma=0.1)

    return model, train_loader, test_loader, optimizer, scheduler
예제 #3
0
def load(args):
    """Set up an L0Reg-gated VGG model and its CIFAR-10 training pieces.

    Gates are built with ``L0Reg`` and a config containing weight_decay=1e-4,
    ``lamb`` taken from ``args.lamb`` and droprate_init=0.2. The Adam
    optimizer itself uses lr=1e-3 and no explicit weight_decay. The
    scheduler scales the learning rate by 0.1 at the (truncated) epochs
    corresponding to 50% and 80% of ``args.num_epochs``.

    Args:
        args: Experiment namespace. ``add_args`` is applied to it first;
            ``args.lamb``, ``args.batch_size`` and ``args.num_epochs`` are
            read afterwards.

    Returns:
        Tuple ``(net, train_loader, test_loader, optimizer, scheduler)``.
    """
    # NOTE(review): presumably merges this experiment's `sub_args` into
    # `args` in place; confirm against add_args.
    add_args(args, sub_args)

    model = VGG(10)  # 10 presumably = number of CIFAR-10 classes; confirm
    # Weight decay is handed to the gate builder rather than to Adam below.
    # The semantics of 'lamb' and 'droprate_init' live in L0Reg; not shown here.
    gate_config = {
        'weight_decay': 1e-4,
        'lamb': args.lamb,
        'droprate_init': 0.2,
    }
    model.build_gate(L0Reg, gate_config)
    train_loader, test_loader = get_CIFAR10(args.batch_size)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    decay_epochs = [int(frac * args.num_epochs) for frac in (.5, .8)]
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=decay_epochs, gamma=0.1)

    return model, train_loader, test_loader, optimizer, scheduler