def getNetwork(args):
    """Build the network selected by ``args.net_type``.

    Args:
        args: parsed options; ``net_type`` must be ``'alexnet'`` or
            ``'resnet'``, and ``depth`` is read for ResNet.

    Returns:
        ``(net, file_name)`` where ``file_name`` is the name prefix used for
        this architecture (ResNet includes its depth).

    Exits the process with status 1 if ``net_type`` is not supported.
    """
    # ``num_classes`` and ``inputs`` are free names resolved at module scope.
    if args.net_type == 'alexnet':
        net = AlexNet(num_classes, inputs)
        file_name = 'alexnet-'
    elif args.net_type == 'resnet':
        net = ResNet(args.depth, num_classes, inputs)
        file_name = 'resnet-' + str(args.depth)
    else:
        # List only the architectures this function actually handles, and exit
        # with a failure status (0 would tell the shell the run succeeded).
        print('Error : Network should be either [AlexNet / ResNet]', file=sys.stderr)
        sys.exit(1)

    return net, file_name
Example #2
0
def getNetwork(args):
    """Build the network selected by ``args.net_type``.

    Args:
        args: parsed options; ``net_type`` must be ``'alexnet'`` or
            ``'resnet'``, and ``depth`` is read for ResNet.

    Returns:
        ``(net, file_name)`` where ``file_name`` is the name prefix used for
        this architecture. For ResNet it reflects the actual ``args.depth``
        (previously it was always ``'resnet-18'``, mislabelling other depths).

    Exits the process with status 1 if ``net_type`` is not supported.
    """
    # ``num_classes`` and ``inputs`` are free names resolved at module scope.
    if args.net_type == 'alexnet':
        net = AlexNet(num_classes, inputs)
        file_name = 'alexnet-'
    elif args.net_type == 'resnet':
        net = ResNet(args.depth, num_classes, inputs)
        file_name = 'resnet-' + str(args.depth)
    else:
        # Exit with a failure status; 0 would signal success to the caller.
        print('Error : Network should be either [AlexNet / ResNet]', file=sys.stderr)
        sys.exit(1)

    return net, file_name
Example #3
0
def load_model():
    """Load the business-card classifier checkpoint into the global ``model``.

    The object stored under the checkpoint's ``'net'`` key is used directly as
    the model. It is loaded onto the CPU first. When ``use_cuda`` is set, it is
    moved to the GPU and wrapped in ``DataParallel`` over all visible devices.
    Finally it is put in eval mode.
    """
    global model
    model_path = "./checkpoint/business_cards/resnet-34.t7"
    # SECURITY: torch.load unpickles arbitrary objects; only load trusted files.
    checkpoint = torch.load(model_path, map_location='cpu')
    # The checkpoint carries the whole module, so no freshly constructed ResNet
    # is needed (the previous code built one and immediately discarded it).
    model = checkpoint['net']
    if use_cuda:
        model.cuda()
        model = torch.nn.DataParallel(model,
                                      device_ids=range(torch.cuda.device_count()))
        cudnn.benchmark = True
    model.eval()
def _haq_strategy(network_name):
    """Return the HAQ per-layer strategy list for ``network_name``, or ``None``.

    Only ``'resnet50'`` and ``'efficientnet-b0'`` have a strategy. The values
    are presumably per-layer bit widths passed to ``add_lsqmodule`` — confirm.
    """
    if network_name == 'resnet50':
        return [6, 6, 5, 5, 5, 5, 4, 5, 5, 4, 5, 5, 5, 5, 5, 5, 3, 5, 4, 3, 5, 4, 3, 4, 4, 4, 2, 5,
                4, 3, 3, 5, 3, 2, 5, 3, 2, 4, 3, 2, 5, 3, 2, 5, 3, 4, 2, 5, 2, 3, 4, 2, 3, 4]
    if network_name == 'efficientnet-b0':
        return [7, 8, 8, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 6, 6, 6,
                6, 6, 6, 6, 6, 7, 6, 7, 6, 7, 6, 5, 6, 5, 6, 4, 5, 6, 5, 6, 4, 4, 5, 4, 5, 2,
                3, 4, 3, 4, 2, 3, 4, 4, 7, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2, 4, 2, 5, 5, 2,
                4, 3, 3, 2]
    return None


def main():
    """Run LSQ quantized training of a student network together with a teacher.

    Options come from ``get_arguments()``. The student constructor and data
    loaders are chosen from ``args.dataset`` / ``args.network``, and the teacher
    (``t_net``) is built where one is configured. Pretrained student weights are
    loaded from ``args.model_root``. LSQ modules are then inserted, either with
    a uniform ``args.weight_bits`` or, with ``args.haq``, a per-network
    strategy. Training is driven by ``Trainer``.

    Unsupported configurations print a message and return early. The teacher is
    missing for cifar10/resnet18/resnet50, or ``args.haq`` is set for a network
    without a strategy. Previously these cases crashed later with
    UnboundLocalError.
    """
    # Resolve relative paths against this script's directory. abspath is
    # required: dirname of a bare filename is '' and os.chdir('') raises.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    args = get_arguments()
    constr_weight = get_constraint(args.weight_bits, 'weight')
    constr_activation = get_constraint(args.activation_bits, 'activation')

    # Only the cifar100 and efficientnet-b0 branches construct a teacher.
    t_net = None
    if args.dataset == 'cifar10':
        network = resnet20
        dataloader = dataloader_cifar10
    elif args.dataset == 'cifar100':
        t_net = ResNet(depth=56, num_classes=100)
        state = torch.load("/prj/neo_lv/user/ybhalgat/LSQ-KD/cifar100_pretrained/resnet56.pth.tar")
        t_net.load_state_dict(state)
        network = resnet20
        dataloader = dataloader_cifar100
    else:
        if args.network == 'resnet18':
            network = resnet18
        elif args.network == 'resnet50':
            network = resnet50
        elif args.network == 'efficientnet-b0':
            # NOTE(review): teacher is EfficientNet-B3 for a B0 student — confirm intended.
            t_net = EfficientNet.from_pretrained("efficientnet-b3")
            network = efficientnet_b0
        else:
            print('Not Support Network Type: %s' % args.network)
            return
        dataloader = dataloader_imagenet

    # get_optimizer and Trainer below are always handed a teacher, so fail
    # early (before loading data) instead of crashing at t_net.cuda().
    if t_net is None:
        print('No teacher network configured for dataset: %s, network: %s'
              % (args.dataset, getattr(args, 'network', None)))
        return

    strategy = None
    if args.haq:
        strategy = _haq_strategy(args.network)
        if strategy is None:
            print('No HAQ strategy for network type: %s' % args.network)
            return

    train_loader = dataloader(args.data_root, split='train', batch_size=args.batch_size)
    test_loader = dataloader(args.data_root, split='test', batch_size=args.batch_size)
    net = network(quan_first=args.quan_first,
                  quan_last=args.quan_last,
                  constr_activation=constr_activation,
                  preactivation=args.preactivation,
                  bw_act=args.activation_bits)

    model_path = os.path.join(args.model_root, args.model_name + '.pth.tar')
    if not os.path.exists(model_path):
        # Fall back to the same name without the trailing '.tar'.
        model_path = model_path[:-4]
    name_weights_old = torch.load(model_path)
    # Overlay the pretrained entries on the fresh state dict so keys missing
    # from the checkpoint keep the network's initial values.
    name_weights_new = net.state_dict()
    name_weights_new.update(name_weights_old)
    load_checkpoint(net, name_weights_new)

    if not args.haq:
        add_lsqmodule(net, bit_width=args.weight_bits)
    else:
        add_lsqmodule(net, strategy=strategy)

    print(net)
    net = net.cuda()
    net = nn.DataParallel(net, device_ids=range(cuda.device_count()))

    t_net = t_net.cuda()
    t_net = nn.DataParallel(t_net, device_ids=range(cuda.device_count()))

    quan_activation = isinstance(constr_activation, np.ndarray)
    postfix = '_w' if not quan_activation else '_a'
    new_model_name = args.prefix + args.model_name + '_lsq' + postfix
    cache_root = os.path.join('.', 'cache')
    train_loger = LogHelper(new_model_name, cache_root, quan_activation, args.resume)
    optimizer, lr_scheduler, optimizer_t = get_optimizer(s_net=net,
                                            t_net=t_net,
                                            optimizer=args.optimizer,
                                            lr_base=args.learning_rate,
                                            weight_decay=args.weight_decay,
                                            lr_scheduler=args.lr_scheduler,
                                            total_epoch=args.total_epoch,
                                            quan_activation=quan_activation,
                                            act_lr_factor=args.act_lr_factor,
                                            weight_lr_factor=args.weight_lr_factor)
    trainer = Trainer(net=net,
                      t_net=t_net,
                      train_loader=train_loader,
                      test_loader=test_loader,
                      optimizer=optimizer,
                      optimizer_t=optimizer_t,
                      lr_scheduler=lr_scheduler,
                      model_name=new_model_name,
                      train_loger=train_loger)
    trainer(total_epoch=args.total_epoch,
            save_check_point=True,
            resume=args.resume)
def main():
    """Train a deep ensemble and report out-of-distribution AUROC every epoch.

    The run configuration is fixed in ``args`` below. FashionMNIST uses ``CNN``
    members with MNIST as the OOD set; any other dataset (CIFAR10) uses
    ``ResNet`` members with SVHN as the OOD set. After the last epoch the
    ensemble's state dict is written to ``model<dataset>_<size>_ensemble.pt``.
    """
    # Epochs, lr, Dataset={"FashionMNIST","CIFAR10"}
    args = {'epochs': 30, 'lr': 0.05, 'ensemble': 5, 'dataset': "FashionMNIST"}
    loss_fn = F.nll_loss
    dataset_name = args['dataset']
    is_fashion = dataset_name == "FashionMNIST"

    input_size, num_classes, train_dataset, test_dataset = all_datasets[dataset_name]()
    loader_opts = {"num_workers": 4, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=True, **loader_opts)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=5000, shuffle=False, **loader_opts)

    # Architecture and LR milestones depend on the dataset.
    if is_fashion:
        member_cls, milestones = CNN, [10, 20]
    else:
        member_cls, milestones = ResNet, [25, 50]
    ensemble = torch.nn.ModuleList(
        [member_cls(input_size, num_classes).cuda() for _ in range(args['ensemble'])])

    # One optimizer per member so weight decay and momentum only ever apply to
    # the member currently being trained.
    optimizers = [
        torch.optim.SGD(member.parameters(), lr=args['lr'], momentum=0.9, weight_decay=5e-4)
        for member in ensemble
    ]
    schedulers = [
        torch.optim.lr_scheduler.MultiStepLR(opt, milestones=milestones, gamma=0.1)
        for opt in optimizers
    ]

    for epoch in range(1, args['epochs'] + 1):
        # Train each member for one epoch, then advance its LR schedule.
        for member, opt, sched in zip(ensemble, optimizers, schedulers):
            train(member, train_loader, opt, epoch, loss_fn)
            sched.step()

        # Evaluate on the main dataset's test split.
        test(ensemble, test_loader, loss_fn)

        # AUROC of main dataset vs. its OOD counterpart.
        if is_fashion:
            _, auroc = get_fm_mnist_ood_ensemble(ensemble)
            print({'mnist_ood_auroc': auroc})
        else:
            _, auroc = get_cifar10_svhn_ood_ensemble(ensemble)
            print({'cifar10_ood_auroc': auroc})

    torch.save(ensemble.state_dict(), f"model{dataset_name}_{len(ensemble)}" + "_ensemble.pt")
Example #6
0
                                    class_mode="categorical",
                                    target_size=(64, 64),
                                    color_mode="rgb",
                                    shuffle=False,
                                    batch_size=BS)

# initialise the testing generator
# Uses the same valAug augmenter as validation. shuffle=False keeps a fixed
# file order (presumably so predictions can be aligned with testGen's labels
# later — confirm against the code that follows).
testGen = valAug.flow_from_directory(config.TEST_PATH,
                                     class_mode="categorical",
                                     target_size=(64, 64),
                                     color_mode="rgb",
                                     shuffle=False,
                                     batch_size=BS)

# initialise our ResNet model and compile it
# NOTE(review): positional args presumably mean width=64, height=64, depth=3,
# classes=2, then stage sizes and filter counts — consistent with
# target_size=(64, 64) and color_mode="rgb" above; confirm against ResNet.build.
model = ResNet.build(64, 64, 3, 2, (3, 4, 6), (64, 128, 256, 512), reg=0.0005)
# NOTE(review): `lr=` is the legacy Keras keyword; newer Keras expects `learning_rate=`.
opt = SGD(lr=INIT_LR, momentum=0.9)
# NOTE(review): binary_crossentropy is paired with class_mode="categorical"
# (one-hot) generators; confirm this is intended (categorical_crossentropy is
# the usual match for one-hot labels).
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])

# define our set of callbacks and fit the model
# poly_decay is defined outside this view; presumably it maps epoch -> learning rate.
callbacks = [LearningRateScheduler(poly_decay)]

# Integer division means a trailing partial batch is not counted in the steps.
# NOTE(review): fit_generator is deprecated in TF 2.x Keras; model.fit accepts generators.
H = model.fit_generator(trainGen,
                        steps_per_epoch=totalTrain // BS,
                        validation_data=valGen,
                        validation_steps=totalVal // BS,
                        epochs=NUM_EPOCHS,
                        callbacks=callbacks)

# reset the testing generator and use the trained model to
# make predictions on the data
Example #7
0
def loadModel(config, checkpoint_path='results/checkpoint.pt', map_location=None):
    """Build a ``ResNet`` from ``config`` and load saved weights into it.

    Args:
        config: configuration object passed straight to ``ResNet``.
        checkpoint_path: file containing the state dict; defaults to the
            previously hard-coded ``'results/checkpoint.pt'``.
        map_location: forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a GPU-saved checkpoint on a CPU-only machine. ``None``
            keeps ``torch.load``'s default behaviour.

    Returns:
        The ``ResNet`` instance with the loaded weights.
    """
    net = ResNet(config)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    net.load_state_dict(state_dict)
    return net
Example #8
0
from utils.resnet import ResNet
from utils.dataTool import train_val_test_loader
from utils.trainer import train_test_model
from utils.baseline import testLastHourRegression, testArima
from config import Config
import torch

# Global settings
# Set before any model/tensors are created below so they default to float64.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch;
# torch.set_default_dtype(torch.float64) is the suggested replacement.
torch.set_default_tensor_type(torch.DoubleTensor)
config = Config()
# Load the training set and validation set (two loaders are unpacked here,
# bound as train_loader and test_loader)
train_loader, test_loader = train_val_test_loader(config)
# Define the network, then train and test it
net = ResNet(config).to(config.device)
train_test_model(net, train_loader, test_loader, config)
# Evaluate the baselines
testLastHourRegression(test_loader, config.device, config.metric, hour_type='last_day')  # previous day
testLastHourRegression(test_loader, config.device, config.metric, hour_type='last_week')  # previous week
testArima(config)  # ARIMA