Code Example #1
def __init__(self, class_num, droprate=0.5, stride=2):
    super(ft_net44, self).__init__()
    # Register the CIFAR-10 ResNet-44 backbone under the submodule name "module"
    # and load the pretrained CIFAR-10 checkpoint into it.
    self.add_module("module", resnet.resnet44())
    weights_ = torch.load("weights_cifar10/resnet44-014dd654.th")
    self.load_state_dict(weights_['state_dict'])
    # Drop the pretrained 10-way linear head and attach a new classifier
    # on the 64-dimensional backbone features.
    self.module.linear = nn.Sequential()
    self.classifier = ClassBlock(64, class_num, droprate)
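
For reference, the ResNet-44 backbone can be exercised on its own without the ClassBlock head. A minimal sketch, assuming the same `resnet` module as above and that the backbone exposes its final layer as `linear` (as the example implies); the input is a standard CIFAR-sized batch:

import torch
import torch.nn as nn
import resnet  # the same CIFAR ResNet module used by the example above

backbone = resnet.resnet44()
backbone.linear = nn.Sequential()  # strip the 10-way head, as the example does
with torch.no_grad():
    feats = backbone(torch.randn(4, 3, 32, 32))
print(feats.shape)  # expected: torch.Size([4, 64]), the width ClassBlock consumes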
Code Example #2
    save_fold_name = [
        args.model,
        str(args.depth), args.dataset,
        'BS%d' % args.batch_size
    ]
    if args.origin:
        save_fold_name.insert(0, 'Origin')

    if args.model == 'resnet':
        if args.depth == 20:
            network = resnet.resnet20()
        if args.depth == 32:
            network = resnet.resnet32()
        if args.depth == 44:
            network = resnet.resnet44()
        if args.depth == 56:
            network = resnet.resnet56()
        if args.depth == 110:
            network = resnet.resnet110()

    if not args.origin:
        print('Pruning the model in %s' % args.pruned_model_dir)
        # Restore the best checkpoint and the pruning codebook saved in the
        # same directory.
        check_point = torch.load(args.pruned_model_dir + "model_best.pth.tar")
        network.load_state_dict(check_point['state_dict'])
        codebook_index_list = np.load(args.pruned_model_dir + "codebook.npy",
                                      allow_pickle=True).tolist()
        m_l = []
        b_l = []

        for i in network.modules():
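
The depth-to-constructor chain above can also be written as a lookup table. A small sketch using only the constructors already referenced in this example:

# Table-driven selection of the CIFAR ResNet constructor by depth
# (equivalent to the if-chain above).
RESNET_BY_DEPTH = {
    20: resnet.resnet20,
    32: resnet.resnet32,
    44: resnet.resnet44,
    56: resnet.resnet56,
    110: resnet.resnet110,
}
if args.model == 'resnet':
    network = RESNET_BY_DEPTH[args.depth]()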
Code Example #3
	# p      = Replica            ("-")
	if dset == "SimpleModel_CIFAR10":
		loader = CIFAR10Loader      (batch_size, p.getSpeeds(), p.getBatches())
		model  = SimpleCIFAR10Model ()
		num_epochs = 10

	elif dset == "RS_SimpleModel_CIFAR10":
		loader = CIFAR10ResnetLoader(batch_size, p.getSpeeds(), p.getBatches())
		import resnet
		
		if int(sys.argv[7]) == 20:
			model  = resnet.resnet20()
		if int(sys.argv[7]) == 32:
			model  = resnet.resnet32()
		if int(sys.argv[7]) == 44:
			model  = resnet.resnet44()
		if int(sys.argv[7]) == 56:
			model  = resnet.resnet56()
		if int(sys.argv[7]) == 110:
			model  = resnet.resnet110()

		num_epochs = 7

	elif dset == "MNIST":
		loader = MNISTLoader      (batch_size, p.getSpeeds(), p.getBatches())
		model  = SimpleMNISTModel ()
		num_epochs = 10
	else:
		print("DATASET NOT FOUND")

	p.setData  (loader)
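
Note that if `dset` matches none of the branches, the final `else` only prints a message, so `loader` and `model` are never assigned and the `p.setData(loader)` call then fails with a `NameError`; raising an exception in that branch would surface the configuration error sooner.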
Code Example #4
File: train.py  Project: RoseTele/pytorch_resnet
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=64)
    parser.add_argument('--nEpochs', type=int, default=300)
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--net')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--opt',
                        type=str,
                        default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--gpu_id', type=str, default='0')

    args = parser.parse_args()

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = 'work/' + args.net

    setproctitle.setproctitle(args.save)

    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    os.makedirs(args.save)

    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]
    normTransform = transforms.Normalize(normMean, normStd)

    trainTransform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normTransform
    ])
    testTransform = transforms.Compose([transforms.ToTensor(), normTransform])

    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    trainLoader = DataLoader(dset.CIFAR10(root='cifar',
                                          train=True,
                                          download=True,
                                          transform=trainTransform),
                             batch_size=args.batchSz,
                             shuffle=True,
                             **kwargs)
    testLoader = DataLoader(dset.CIFAR10(root='cifar',
                                         train=False,
                                         download=True,
                                         transform=testTransform),
                            batch_size=args.batchSz,
                            shuffle=False,
                            **kwargs)

    n_classes = 10
    if args.net == 'resnet20':
        net = resnet.resnet20(num_classes=n_classes)
    elif args.net == 'resnet32':
        net = resnet.resnet32(num_classes=n_classes)
    elif args.net == 'resnet44':
        net = resnet.resnet44(num_classes=n_classes)
    elif args.net == 'resnet56':
        net = resnet.resnet56(num_classes=n_classes)
    elif args.net == 'resnet110':
        net = resnet.resnet110(num_classes=n_classes)
    elif args.net == 'resnetxt29':
        net = resnetxt.resnetxt29(num_classes=n_classes)
    elif args.net == 'deform_resnet32':
        net = deformconvnet.deform_resnet32(num_classes=n_classes)
    else:
        net = densenet.DenseNet(growthRate=12,
                                depth=100,
                                reduction=0.5,
                                bottleneck=True,
                                nClasses=n_classes)

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))
    if args.cuda:
        net = net.cuda()
        gpu_id = args.gpu_id
        gpu_list = gpu_id.split(',')
        gpus = [int(i) for i in gpu_list]
        net = nn.DataParallel(net, device_ids=gpus)

    if args.opt == 'sgd':
        optimizer = optim.SGD(net.parameters(),
                              lr=1e-1,
                              momentum=0.9,
                              weight_decay=1e-4)
    elif args.opt == 'adam':
        optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)
    elif args.opt == 'rmsprop':
        optimizer = optim.RMSprop(net.parameters(), weight_decay=1e-4)

    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')

    for epoch in range(1, args.nEpochs + 1):
        adjust_opt(args.opt, optimizer, epoch)
        train(args, epoch, net, trainLoader, optimizer, trainF)
        test(args, epoch, net, testLoader, optimizer, testF)
        torch.save(net, os.path.join(args.save, 'latest.pth'))
        os.system('python plot.py {} &'.format(args.save))

    trainF.close()
    testF.close()
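
main() steps the learning rate through `adjust_opt`, which lives elsewhere in train.py and is not part of this excerpt. A minimal sketch of a typical step schedule for this kind of CIFAR run, with the epoch breakpoints and rates as assumptions rather than the project's actual values:

def adjust_opt(opt_name, optimizer, epoch):
    # Only the SGD run is stepped; Adam/RMSprop keep their default settings.
    if opt_name != 'sgd':
        return
    if epoch < 150:
        lr = 1e-1          # matches the initial SGD learning rate above
    elif epoch < 225:
        lr = 1e-2
    else:
        lr = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr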