Example #1
def main():

    # init conv net

    net = vgg.vgg11()
    if os.path.exists("./vgg.pkl"):
        net.load_state_dict(torch.load("./vgg.pkl"))
        print("load model")
    net.cuda()
    net.eval()

    print("load ok")

    while True:
        pull_screenshot(
            "autojump.png")  # obtain screen and save it to autojump.png
        image = Image.open('./autojump.png')
        set_button_position(image)
        image = preprocess(image)

        image = Variable(image.unsqueeze(0)).cuda()
        press_time = net(image).cpu().data[0].numpy()
        print(press_time)
        jump(press_time)

        time.sleep(random.uniform(1.5, 2))
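pull_screenshot, set_button_position and jump are project helpers that are not shown here. A preprocess for a 224x224 VGG input might look like the following sketch (an assumption, not the project's actual code):

from torchvision import transforms

def preprocess(image):
    # Hypothetical sketch: resize the screenshot, convert it to a tensor,
    # and normalize with ImageNet statistics.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return transform(image)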
Example #2
    def init_net(self):

        net_args = {
            "pretrained": True,
            "n_input_channels": len(self.kwargs["static"]["imagery_bands"])
        }

        # https://pytorch.org/docs/stable/torchvision/models.html
        if self.kwargs["net"] == "resnet18":
            self.model = resnet.resnet18(**net_args)
        elif self.kwargs["net"] == "resnet34":
            self.model = resnet.resnet34(**net_args)
        elif self.kwargs["net"] == "resnet50":
            self.model = resnet.resnet50(**net_args)
        elif self.kwargs["net"] == "resnet101":
            self.model = resnet.resnet101(**net_args)
        elif self.kwargs["net"] == "resnet152":
            self.model = resnet.resnet152(**net_args)
        elif self.kwargs["net"] == "vgg11":
            self.model = vgg.vgg11(**net_args)
        elif self.kwargs["net"] == "vgg11_bn":
            self.model = vgg.vgg11_bn(**net_args)
        elif self.kwargs["net"] == "vgg13":
            self.model = vgg.vgg13(**net_args)
        elif self.kwargs["net"] == "vgg13_bn":
            self.model = vgg.vgg13_bn(**net_args)
        elif self.kwargs["net"] == "vgg16":
            self.model = vgg.vgg16(**net_args)
        elif self.kwargs["net"] == "vgg16_bn":
            self.model = vgg.vgg16_bn(**net_args)
        elif self.kwargs["net"] == "vgg19":
            self.model = vgg.vgg19(**net_args)
        elif self.kwargs["net"] == "vgg19_bn":
            self.model = vgg.vgg19_bn(**net_args)

        else:
            raise ValueError("Invalid network specified: {}".format(
                self.kwargs["net"]))

        #  run type: 1 = fine tune, 2 = fixed feature extractor
        #  - replace run type option with "# of layers to fine tune"
        if self.kwargs["run_type"] == 2:
            layer_count = len(list(self.model.parameters()))
            for layer, param in enumerate(self.model.parameters()):
                if layer <= layer_count - 5:
                    param.requires_grad = False

        # Parameters of newly constructed modules have requires_grad=True by default
        # get existing number for input features
        # set new number for output features to number of categories being classified
        # see: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
        if "resnet" in self.kwargs["net"]:
            num_ftrs = self.model.fc.in_features
            self.model.fc = nn.Linear(num_ftrs, self.ncats)
        elif "vgg" in self.kwargs["net"]:
            num_ftrs = self.model.classifier[6].in_features
            self.model.classifier[6] = nn.Linear(num_ftrs, self.ncats)
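Since the factory functions follow the torchvision naming scheme, the long if/elif chain could also be collapsed with getattr; a sketch assuming the same resnet and vgg modules:

def build_model(net_name, net_args):
    # Resolve e.g. "vgg16_bn" to vgg.vgg16_bn instead of enumerating cases.
    module = resnet if net_name.startswith("resnet") else vgg
    try:
        factory = getattr(module, net_name)
    except AttributeError:
        raise ValueError("Invalid network specified: {}".format(net_name))
    return factory(**net_args)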
Example #3
def prune_vgg(model, pruning_strategy, cuda=True, dataparallel=True):
    cfg, cfg_mask = _calculate_channel_mask(model, pruning_strategy, cuda=cuda)
    pruned_model = vgg11(config=cfg)
    if cuda:
        pruned_model.cuda()
    if dataparallel:
        pruned_model.features = torch.nn.DataParallel(pruned_model.features)
    assign_model(model, pruned_model, cfg_mask)

    return pruned_model, cfg
Example #4
def load_model(args):
    if args.model == 'vgg11':
        model = vgg.vgg11().to(device)
    elif args.model == 'vgg13':
        model = vgg.vgg13().to(device)
    elif args.model == 'vgg16':
        model = vgg.vgg16().to(device)
    elif args.model == 'vgg19':
        model = vgg.vgg19().to(device)
    elif args.model == 'modified_vgg11':
        model = modified_vgg.vgg11().to(device)
    elif args.model == 'modified_vgg13':
        model = modified_vgg.vgg13().to(device)
    elif args.model == 'modified_vgg16':
        model = modified_vgg.vgg16().to(device)
    elif args.model == 'modified_vgg19':
        model = modified_vgg.vgg19().to(device)
    else:
        raise ValueError('unknown model: {}'.format(args.model))
    return model
Example #5
def main():

    # init conv net
    print("init net")
    net = vgg.vgg11()
    if os.path.exists("./model.pkl"):
        net.load_state_dict(torch.load("./model.pkl"))
        print("load model")

    net.cuda()

    # init dataset
    print("init dataset")
    data_loader = dataset.jump_data_loader()

    # init optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
    criterion = nn.MSELoss()

    # train
    print("training...")
    for epoch in range(1000):
        for i, (images, press_times) in enumerate(data_loader):
            images = Variable(images).cuda()
            press_times = Variable(press_times.float()).cuda()

            predict_press_times = net(images)

            loss = criterion(predict_press_times, press_times)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % 10 == 0:
                # loss.item() replaces the deprecated loss.data[0]
                print("epoch:", epoch, "step:", i, "loss:", loss.item())
            if (epoch + 1) % 5 == 0 and i == 0:
                torch.save(net.state_dict(), "./vgg.pkl")
                print("save model")
Example #6
File: main.py Project: L0SG/WPI_2
import tensorflow as tf
import numpy as np
from vgg import vgg11
import cifar100_utils

(train_x, train_y), (test_x, test_y) = cifar100_utils.load_data()

# if trained the model before, load the weights
weights = None

with tf.Session() as sess:
    # instantiate the model
    vgg_model = vgg11(weights=weights, sess=sess)
    # train the model
    vgg_model.train(images=train_x,
                    labels=train_y,
                    epochs=100,
                    val_split=0.1,
                    save_weights=True)
    # predict labels for the test images
    preds = vgg_model.predict(images=test_x)

    # calculate accuracy
    accuracy = np.sum([preds[i] == test_y[i] for i in range(test_y.shape[0])])
    accuracy /= test_y.shape[0]
    print('test accuracy: ' + str(accuracy))
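Assuming preds and test_y have matching shapes, the accuracy computation can also be written as a single vectorized expression:

accuracy = np.mean(np.asarray(preds) == test_y)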
Example #7
    DT_val = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_sizes,
                                         shuffle=True)

    def NeptuneLog():
        neptune.log_metric('batch_size', batch_sizes)
        neptune.log_metric('learning_rate', learning_rate)
        neptune.log_text('pre-trained', str(pretrain_check))
        neptune.log_text('model', model_name)
        neptune.log_text('date_time', date_time)

    neptune.create_experiment(model_name)
    NeptuneLog()

    if model_name == 'vgg11':
        model = vgg.vgg11(pretrained=pretrain_check)
    elif model_name == 'vgg11_bn':
        model = vgg.vgg11_bn(pretrained=pretrain_check)
    elif model_name == 'vgg13':
        model = vgg.vgg13(pretrained=pretrain_check)
    elif model_name == 'vgg13_bn':
        model = vgg.vgg13_bn(pretrained=pretrain_check)
    elif model_name == 'vgg16':
        model = vgg.vgg16(pretrained=pretrain_check)
    elif model_name == 'vgg16_bn':
        model = vgg.vgg16_bn(pretrained=pretrain_check)
    elif model_name == 'vgg19':
        model = vgg.vgg19(pretrained=pretrain_check)
    elif model_name == 'vgg19_bn':
        model = vgg.vgg19_bn(pretrained=pretrain_check)
    else:
        raise ValueError('unknown model_name: {}'.format(model_name))
    model.eval()
Example #8
    ])  # data preprocessing / normalization

    dataset_cifar10 = CIFAR10(root=Config.DATASETS_ROOT,
                              train=True,
                              transform=data_transform,
                              download=True)  # CIFAR10 dataset
    train_sampler, valid_sampler = dataset_split(dataset_cifar10,
                                                 shuffle=True)  # samplers

    train_data = DataLoader(
        dataset=dataset_cifar10,
        batch_size=Config.TRAIN_BATCH_SIZE,
        sampler=train_sampler)  # train data; sampler and shuffle are mutually exclusive

    valid_data = DataLoader(
        dataset=dataset_cifar10,
        batch_size=Config.TRAIN_BATCH_SIZE,
        sampler=valid_sampler)  # valid data; sampler and shuffle are mutually exclusive

    model = vgg.vgg11(10)  # renamed from `vgg` to avoid shadowing the imported module
    print(model)  # inspect the model

    # optimizer = torch.optim.Adam(model.parameters())
    optimizer = torch.optim.SGD(model.parameters(), lr=Config.LEARN_RATE)
    loss_func = nn.CrossEntropyLoss()

    train(model, Config.EPOCHS, optimizer, loss_func, train_data, valid_data,
          'vgg11')
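dataset_split is project code that is not shown; such train/valid samplers are commonly built with SubsetRandomSampler, as in this sketch (assuming a 90/10 split):

import numpy as np
from torch.utils.data import SubsetRandomSampler

def dataset_split(dataset, valid_ratio=0.1, shuffle=True):
    # Hypothetical sketch: partition indices into train/valid subsets.
    indices = np.arange(len(dataset))
    if shuffle:
        np.random.shuffle(indices)
    split = int(len(indices) * (1 - valid_ratio))
    return (SubsetRandomSampler(indices[:split]),
            SubsetRandomSampler(indices[split:]))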
Example #9
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    args.distributed = args.world_size > 1

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    model = vgg11()

    # if not args.distributed:
    #     if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
    #         model.features = torch.nn.DataParallel(model.features)
    #         model.cuda()
    #     else:
    #         model = torch.nn.DataParallel(model).cuda()
    # else:
    #     model.cuda()
    #     model = torch.nn.parallel.DistributedDataParallel(model)
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # traindir = '/data5/ILSVRC/Data/CLS-LOC/train'
    # valdir = '/data5/ILSVRC/Data/CLS-LOC/val'
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    # if args.distributed:
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # else:
    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        file_train = open(
            "/home/hg31/work/Torch_SI/csv/ImageNet_" + str(5) +
            ".train_rndCropFlip.SI.train" + ".csv", "a")
        file_test = open(
            "/home/hg31/work/Torch_SI/csv/ImageNet_" + str(5) +
            ".test.SI.test" + ".csv", "a")
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, file_train,
              args.s)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, file_test)

        file_train.close()
        file_test.close()
Example #10
def main_test(args):
    start_time = time.time()
    now = datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')
    # define paths

    logger = SummaryWriter('../logs')

    # commented out when using easydict
    # args = args_parser()

    # where checkpoints are saved
    args.save_path = os.path.join(args.save_path, args.exp_folder)
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    save_path_tmp = os.path.join(args.save_path, 'tmp_{}'.format(now))
    if not os.path.exists(save_path_tmp):
        os.makedirs(save_path_tmp)
    SAVE_PATH = os.path.join(args.save_path, '{}_{}_T[{}]_C[{}]_iid[{}]_E[{}]_B[{}]'.
                             format(args.dataset, args.model, args.epochs, args.frac, args.iid,
                                    args.local_ep, args.local_bs))

    # fix random seeds
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)



#    torch.cuda.set_device(0)
    device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() else "cpu")
    cpu_device = torch.device('cpu')
    # create the log file
    log_path = os.path.join('../logs', args.exp_folder)
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    loggertxt = get_logger(
        os.path.join(log_path, '{}_{}_{}_{}.log'.format(args.model, args.optimizer, args.norm, now)))
    logging.info(args)
    # csv
    csv_save = '../csv/' + now
    csv_path = os.path.join(csv_save, 'accuracy.csv')
    csv_logger_keys = ['train_loss', 'accuracy']
    csvlogger = CSVLogger(csv_path, csv_logger_keys)

    # load dataset and user groups
    train_dataset, test_dataset, client_loader_dict = get_dataset(args)

    # set automatically for cifar-100
    if args.dataset == 'cifar100':
        args.num_classes = 100
    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        elif args.dataset == 'cifar100':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # build the MLP once the full flattened input size is known
        global_model = MLP(dim_in=len_in, dim_hidden=64,
                           dim_out=args.num_classes)
    elif args.model == 'cnn_vc':
        global_model = CNNCifar_fedVC(args=args)
    elif args.model == 'cnn_vcbn':
        global_model = CNNCifar_VCBN(args=args)
    elif args.model == 'cnn_vcgn':
        global_model = CNNCifar_VCGN(args=args)
    elif args.model == 'resnet18_ws':
        global_model = resnet18(num_classes=args.num_classes, weight_stand=1)
    elif args.model == 'resnet18':
        global_model = resnet18(num_classes=args.num_classes, weight_stand=0)
    elif args.model == 'resnet32':
        global_model = ResNet32_test(num_classes=args.num_classes)
    elif args.model == 'resnet18_mabn':
        global_model = resnet18_mabn(num_classes=args.num_classes)
    elif args.model == 'vgg':
        global_model = vgg11()
    elif args.model == 'cnn_ws':
        global_model = CNNCifar_WS(args=args)


    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    loggertxt.info(global_model)
    # for keeping gn parameters un-communicated per client, as in FedBN
    client_models = [copy.deepcopy(global_model) for idx in range(args.num_users)]

    # copy weights
    global_weights = global_model.state_dict()

    global_model.to(device)
    global_model.train()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []


    # for inspecting "how does BN help" statistics
    client_loss = [[] for i in range(args.num_users)]
    client_conv_grad = [[] for i in range(args.num_users)]
    client_fc_grad = [[] for i in range(args.num_users)]
    client_total_grad_norm = [[] for i in range(args.num_users)]
    # for tracking the overall loss ("how does BN help")

    # resume
    if args.resume:
        checkpoint = torch.load(SAVE_PATH)
        global_model.load_state_dict(checkpoint['global_model'])
        if args.hold_normalize:
            for client_idx in range(args.num_users):
                client_models[client_idx].load_state_dict(checkpoint['model_{}'.format(client_idx)])
        else:
            for client_idx in range(args.num_users):
                client_models[client_idx].load_state_dict(checkpoint['global_model'])
        resume_iter = int(checkpoint['a_iter']) + 1
        print('Resuming training from epoch {}'.format(resume_iter))
    else:
        resume_iter = 0


    # learning rate scheduler
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, gamma=0.1,step_size=500)

    # start training
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        if args.verbose:
            print(f'\n | Global Training Round : {epoch + 1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)


        for idx in idxs_users:
            """
            for key in global_model.state_dict().keys():
                if args.hold_normalize:
                    if 'bn' not in key:
                        client_models[idx].state_dict()[key].data.copy_(global_model.state_dict()[key])
                else:
                    client_models[idx].state_dict()[key].data.copy_(global_model.state_dict()[key])
            """
            torch.cuda.empty_cache()


            local_model = LocalUpdate(args=args, logger=logger, train_loader=client_loader_dict[idx], device=device)
            w, loss, batch_loss, conv_grad, fc_grad, total_grad_norm = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch, idx_user=idx)
            local_weights.append(copy.deepcopy(w))
            # the client's mean loss over one local epoch, e.g. 0.35 (the mean of its batch losses)
            local_losses.append(copy.deepcopy(loss))

            # scheduler over the whole round
            # scheduler.step()
            # for loss graphs: store each client's loss progression, per client.
            client_loss[idx].append(batch_loss)
            client_conv_grad[idx].append(conv_grad)
            client_fc_grad[idx].append(fc_grad)
            client_total_grad_norm[idx].append(total_grad_norm)

            # print(total_grad_norm)
            # copy gn, bn
            # client_models[idx].load_state_dict(w)
            del local_model
            del w
        # update global weights
        global_weights = average_weights(local_weights, client_loader_dict, idxs_users)
        # update global weights
#        opt = OptRepo.name2cls('adam')(global_model.parameters(), lr=0.01, betas=(0.9, 0.99), eps=1e-3)
        opt = OptRepo.name2cls('sgd')(global_model.parameters(), lr=10, momentum=0.9)
        opt.zero_grad()
        opt_state = opt.state_dict()
        global_weights = aggregation(global_weights, global_model)
        global_model.load_state_dict(global_weights)
        opt = OptRepo.name2cls('sgd')(global_model.parameters(), lr=10, momentum=0.9)
#        opt = OptRepo.name2cls('adam')(global_model.parameters(), lr=0.01, betas=(0.9, 0.99), eps=1e-3)
        opt.load_state_dict(opt_state)
        opt.step()
        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        global_model.eval()
        #        for c in range(args.num_users):
        #            local_model = LocalUpdate(args=args, dataset=train_dataset,
        #                                      idxs=user_groups[idx], logger=logger)
        #            acc, loss = local_model.inference(model=global_model)
        #            list_acc.append(acc)
        #            list_loss.append(loss)
        #        train_accuracy.append(sum(list_acc)/len(list_acc))
        train_accuracy = test_inference(args, global_model, test_dataset, device=device)
        val_acc_list.append(train_accuracy)
        # print global training loss after every 'i' rounds
        # if (epoch+1) % print_every == 0:
        loggertxt.info(f' \nAvg Training Stats after {epoch + 1} global rounds:')
        loggertxt.info(f'Training Loss : {loss_avg}')
        loggertxt.info('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy))
        csvlogger.write_row([loss_avg, 100 * train_accuracy])
        if (epoch + 1) % 100 == 0:
            tmp_save_path = os.path.join(save_path_tmp, 'tmp_{}.pt'.format(epoch+1))
            torch.save(global_model.state_dict(),tmp_save_path)
    # Test inference after completion of training
    test_acc = test_inference(args, global_model, test_dataset, device=device)

    print(' Saving checkpoints to {}...'.format(SAVE_PATH))
    if args.hold_normalize:
        client_dict = {}
        for idx, model in enumerate(client_models):
            client_dict['model_{}'.format(idx)] = model.state_dict()
        torch.save(client_dict, SAVE_PATH)
    else:
        torch.save({'global_model': global_model.state_dict()}, SAVE_PATH)

    loggertxt.info(f' \n Results after {args.epochs} global rounds of training:')
    # loggertxt.info("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    loggertxt.info("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))


    # does not work well when frac != 1.
    # batch_loss_list = np.array(client_loss).sum(axis=0) / args.num_users

    # conv_grad_list = np.array(client_conv_grad).sum(axis=0) / args.num_users
    # fc_grad_list = np.array(client_fc_grad).sum(axis=0) / args.num_users
    # total_grad_list = np.array(client_total_grad_norm).sum(axis=0) /args.num_users
    # intended to average over clients, but currently only client 0 is returned
    # a bug is expected if the number of batches differs across clients
    return train_loss, val_acc_list, client_loss[0], client_conv_grad[0], client_fc_grad[0], client_total_grad_norm[0]
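average_weights is defined elsewhere in the project; a FedAvg-style aggregation weighted by each client's dataset size would match the call signature, as in this sketch (an assumption, not the project's actual code):

import copy

def average_weights(local_weights, client_loader_dict, idxs_users):
    # Hypothetical sketch: average client state_dicts, weighting each
    # client by the number of samples in its local loader.
    sizes = [len(client_loader_dict[idx].dataset) for idx in idxs_users]
    total = float(sum(sizes))
    avg = copy.deepcopy(local_weights[0])
    for key in avg.keys():
        avg[key] = sum(w[key] * (n / total)
                       for w, n in zip(local_weights, sizes))
    return avg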
Example #11
def prune_vgg(model, pruning_strategy, cuda=True, dataparallel=True):
    cfg, cfg_mask = _calculate_channel_mask(model, pruning_strategy, cuda=cuda)
    pruned_model = vgg11(config=cfg)
    if cuda:
        pruned_model.cuda()
    if dataparallel:
        pruned_model.features = torch.nn.DataParallel(pruned_model.features)
    assign_model(model, pruned_model, cfg_mask)

    return pruned_model, cfg


if __name__ == '__main__':
    test_model = vgg11().cuda()

    # fake polarization to test pruning
    for name, module in test_model.named_modules():
        if isinstance(module, nn.BatchNorm1d) or isinstance(
                module, nn.BatchNorm2d):
            module.weight.data.zero_()
            one_num = randint(3, 30)
            module.weight.data[:one_num] = 0.99

            print(f"{name} remains {one_num}")

    pruned_model, cfg = prune_vgg(test_model, pruning_strategy="fixed")

    demo_input = torch.rand(3, 3, 224, 224).cuda()
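demo_input is created but never used in the snippet; a forward-pass sanity check is the natural continuation (an assumption, not part of the original):

    with torch.no_grad():
        out = pruned_model(demo_input)
    print("output shape:", out.shape)  # a batch of 3 logit vectors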
Example #12
def main_worker(gpu, ngpus_per_node, args):
    global best_prec1
    args.gpu = gpu

    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if args.rank == 0:
        if not os.path.exists(args.save):
            os.makedirs(args.save)
        if not os.path.exists(args.backup_path):
            os.makedirs(args.backup_path)

    if args.distributed:
        # For multiprocessing distributed training, rank needs to be the
        # global rank among all the processes
        args.rank = args.rank * ngpus_per_node + gpu
        print("Starting process rank {}".format(args.rank))
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    print("rank#{}: CUDA_VISIBLE_DEVICES: {}".format(args.rank, os.environ['CUDA_VISIBLE_DEVICES']))

    if args.arch == "vgg11":
        if args.resume and os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            if "cfg" in checkpoint:
                model = vgg11(config=checkpoint['cfg'])
            else:
                model = vgg11()
        else:
            model = vgg11()
    elif args.arch == "resnet50":
        model = resnet50(mask=False, bn_init_value=args.bn_init_value,
                         aux_fc=False, save_feature_map=False)
    else:
        raise NotImplementedError("model {} is not supported".format(args.arch))

    if not args.distributed:
        # DataParallel
        model.cuda()
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            # see discussion
            # https://discuss.pytorch.org/t/are-there-reasons-why-dataparallel-was-used-differently-on-alexnet-and-vgg-in-the-imagenet-example/19844
            model.features = torch.nn.DataParallel(model.features)
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        # DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr[0],
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.debug:
        # fake polarization to test pruning
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
                module.weight.data.zero_()
                one_num = randint(3, 30)
                module.weight.data[:one_num] = 0.99

                print(f"{name} remains {one_num}")

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(args.resume))

    print("Model loading completed. Model Summary:")
    print(model)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("rank #{}: loading the dataset...".format(args.rank))

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print("rank #{}: dataloader loaded!".format(args.rank))

    if args.evaluate:
        validate(val_loader, model, criterion, epoch=0, args=args, writer=None)
        return

    # only master process in each node write to disk
    writer = SummaryWriter(logdir=args.save, write_to_disk=args.rank % ngpus_per_node == 0)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # the adjustment only takes effect when epoch reaches a decay_epoch
        adjust_learning_rate(optimizer, epoch, lr=args.lr, decay_epoch=args.decay_epoch)

        # draw bn hist to tensorboard
        weights, bias = bn_weights(model)
        for bn_name, bn_weight in weights:
            writer.add_histogram("bn/" + bn_name, bn_weight, global_step=epoch)
        for bn_name, bn_bias in bias:
            writer.add_histogram("bn_bias/" + bn_name, bn_bias, global_step=epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch,
              args.lbd, args=args,
              is_debug=args.debug, writer=writer)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch,
                         args=args, writer=writer)

        report_prune_result(model)  # do not really prune the model

        writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if args.rank % ngpus_per_node == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save,
                save_backup=(epoch - args.start_epoch) % 5 == 0,
                backup_path=args.backup_path,
                epoch=epoch)

        writer.flush()

    writer.close()
    print("Best prec@1: {}".format(best_prec1))
Example #13
def main():
    global args, best_prec1
    args = parser.parse_args()

    if len(args.lr) != len(args.decay_epoch) + 1:
        print("args.lr: {}".format(args.lr))
        print("args.decay-epoch: {}".format(args.decay_epoch))
        raise ValueError("inconsistent between lr-decay-gamma and decay-epoch")

    print(args)

    args.distributed = args.world_size > 1

    # reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.rank == 0:
        if not os.path.exists(args.save):
            os.makedirs(args.save)
        if not os.path.exists(args.backup_path):
            os.makedirs(args.backup_path)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.refine:
        checkpoint = torch.load(args.refine)
        print("loaded checkpoint. cfg:")
        print(checkpoint['cfg'])
        if args.arch == "vgg11":
            model = vgg11(config=checkpoint['cfg'])
        elif args.arch == "resnet50":
            if args.expand:
                if "downsample_cfg" in checkpoint:
                    downsample_cfg = checkpoint["downsample_cfg"]
                else:
                    downsample_cfg = None
                model = ResNetExpand(cfg=checkpoint['cfg'],
                                     aux_fc=False,
                                     downsample_cfg=downsample_cfg)
            else:
                raise NotImplementedError("Use --expand option.")

        else:
            raise NotImplementedError("{} is not supported".format(args.arch))

        model.load_state_dict(checkpoint['state_dict'])

        # there is no parameters in ChannelMask layers
        # we need to load it manually
        if args.expand:
            bn3_masks = checkpoint["bn3_masks"]
            bottleneck_modules = list(
                filter(lambda m: isinstance(m[1], Bottleneck),
                       model.named_modules()))
            assert len(bn3_masks) == len(bottleneck_modules)
            for i, (name, m) in enumerate(bottleneck_modules):
                if isinstance(m, Bottleneck):
                    mask = bn3_masks[i]
                    assert mask[1].shape[0] == m.expand_layer.idx.shape[0]
                    m.expand_layer.idx = np.argwhere(
                        mask[1].clone().cpu().numpy()).squeeze()

        if 'downsample_cfg' in checkpoint:
            # set downsample expand layer
            downsample_modules = list(
                filter(
                    lambda m: isinstance(m[1], nn.Sequential) and 'downsample'
                    in m[0], model.named_modules()))
            downsample_mask = checkpoint['downsample_mask']
            assert len(downsample_modules) == len(downsample_mask)
            for i, (name, m) in enumerate(downsample_modules):
                mask = downsample_mask[f"{name}.1"]
                assert mask.shape[0] == m[-1].idx.shape[0]
                m[-1].idx = np.argwhere(mask.clone().cpu().numpy()).squeeze()

        if not args.distributed:
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                model.features = torch.nn.DataParallel(model.features)
                model.cuda()
            else:
                model = torch.nn.DataParallel(model).cuda()
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr[0],
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1'].item()

            for param_name, param in model.named_parameters():
                if param_name not in checkpoint['state_dict']:
                    checkpoint['state_dict'][param_name] = param.data
                    raise ValueError(
                        "Missing parameter {}, do not load!".format(
                            param_name))

            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            # move optimizer buffer to GPU
            for p in optimizer.state.keys():
                param_state = optimizer.state[p]
                buf = param_state["momentum_buffer"]
                param_state["momentum_buffer"] = buf.cuda(
                )  # move buf to device

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, epoch=0, writer=None)
        return

    # only master process write to disk
    writer = SummaryWriter(logdir=args.save, write_to_disk=args.rank == 0)

    if args.arch == "resnet50":
        summary = pruning_summary_resnet50(model, args.expand)
    else:
        print("WARNING: arch {} do not support pretty print".format(args.arch))
        summary = str(model)

    print(model)

    print("********** MODEL SUMMARY **********")
    print(summary)
    print("********** ************* **********")

    writer.add_text("model/summary", summary)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        weights = bn_weights(model)
        for bn_name, bn_weight in weights:
            writer.add_histogram("bn/" + bn_name, bn_weight, global_step=epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, writer)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer)

        writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            args.save,
            save_backup=(epoch - args.start_epoch) % 5 == 0,
            backup_path=args.backup_path,
            epoch=epoch)

    writer.close()
    print("Best prec@1: {}".format(best_prec1))
Example #14
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    args.distributed = args.world_size > 1

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    model = vgg11()

    if args.scratch:
        checkpoint = torch.load(args.scratch)
        model = vgg11(pretrained=False, config=checkpoint['cfg'])

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, step_size)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args.s)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)
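adjust_learning_rate and step_size come from the surrounding script; a conventional step-decay implementation consistent with the call might be (a sketch, assuming the script's global args):

def adjust_learning_rate(optimizer, epoch, step_size):
    # Hypothetical sketch: decay the base lr by 10x every step_size epochs.
    lr = args.lr * (0.1 ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr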
Example #15
                                            transform=test_transform)
        elif args.model == 'ResNet34':
            model = models.resnet34(pretrained=False)
            
            train_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            test_transform = train_transform
            # ImageFolder takes no train=/download= arguments
            train_dataset = datasets.ImageFolder(args.data_dir,
                                                 transform=train_transform)
            test_dataset = datasets.ImageFolder(args.data_dir,
                                                transform=test_transform)
        elif args.model == 'VGG11':
            model = vgg.vgg11()

            train_transform, test_transform = get_data_transform('cifar')
            train_dataset = datasets.CIFAR100(args.data_dir, train=True, download=False,
                                             transform=train_transform)
            test_dataset = datasets.CIFAR100(args.data_dir, train=False, download=False,
                                            transform=test_transform)
        else:
            print('Model must be {} or {}!'.format('MnistCNN', 'AlexNet'))
            sys.exit(-1)
        models.append(model)
    train_bsz = args.train_bsz
    train_bsz /= len(workers)
    train_bsz = int(train_bsz)

    train_data = partition_dataset(train_dataset, workers)
Example #16
def train_vgg_11(
        pretrained=False,
        dataset_name='imagenet',
        prune=False,
        prune_params='',
        learning_rate=conf.learning_rate,
        num_epochs=conf.num_epochs,
        batch_size=conf.batch_size,
        learning_rate_decay_factor=conf.learning_rate_decay_factor,
        weight_decay=conf.weight_decay,
        num_epochs_per_decay=conf.num_epochs_per_decay,
        checkpoint_step=conf.checkpoint_step,
        checkpoint_path=conf.root_path + 'vgg_11' + conf.checkpoint_path,
        highest_accuracy_path=conf.root_path + 'vgg_11' +
        conf.highest_accuracy_path,
        global_step_path=conf.root_path + 'vgg_11' + conf.global_step_path,
        default_image_size=224,
        momentum=conf.momentum,
        num_workers=conf.num_workers):
    if dataset_name == 'imagenet':
        train_set_size = conf.imagenet['train_set_size']
        mean = conf.imagenet['mean']
        std = conf.imagenet['std']
        train_set_path = conf.imagenet['train_set_path']
        validation_set_path = conf.imagenet['validation_set_path']

    # gpu or not
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('using: ', end='')
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

    #define the model
    net = vgg.vgg11(pretrained).to(device)

    #define loss function and optimizer
    criterion = nn.CrossEntropyLoss()  # cross-entropy loss, typical for multi-class problems
    optimizer = optim.SGD(
        net.parameters(),
        lr=learning_rate)  # mini-batch momentum SGD with L2 regularization (weight decay)

    highest_accuracy = 0
    if os.path.exists(highest_accuracy_path):
        f = open(highest_accuracy_path, 'r')
        highest_accuracy = float(f.read())
        f.close()
        print('highest accuracy from previous training is %f' %
              highest_accuracy)

    global_step = 0
    if os.path.exists(global_step_path):
        f = open(global_step_path, 'r')
        global_step = int(f.read())
        f.close()
        print('global_step at present is %d' % global_step)
        model_saved_at = checkpoint_path + '/global_step=' + str(
            global_step) + '.pth'
        print('load model from ' + model_saved_at)
        net.load_state_dict(torch.load(model_saved_at))

    # Data loading code
    transform = transforms.Compose([
        transforms.RandomResizedCrop(default_image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    train = datasets.ImageFolder(train_set_path, transform)
    val = datasets.ImageFolder(validation_set_path, transform)
    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers)
    validation_loader = torch.utils.data.DataLoader(val,
                                                    batch_size=batch_size,
                                                    shuffle=False,
                                                    num_workers=num_workers)

    print("{} Start training vgg-11...".format(datetime.now()))
    for epoch in range(num_epochs):
        print("{} Epoch number: {}".format(datetime.now(), epoch + 1))
        net.train()

        #one epoch for one loop
        for step, data in enumerate(train_loader, 0):
            # prepare the data
            length = len(train_loader)
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # forward + backward
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            #decay learning rate
            global_step += 1
            decay_steps = int(train_set_size / batch_size *
                              num_epochs_per_decay)
            exponential_decay_learning_rate(optimizer, learning_rate,
                                            global_step, decay_steps,
                                            learning_rate_decay_factor)

            if step % checkpoint_step == 0 and step != 0:
                print("{} Start validation".format(datetime.now()))
                print("{} global step = {}".format(datetime.now(),
                                                   global_step))
                with torch.no_grad():
                    correct = 0
                    total = 0
                    for val_data in validation_loader:
                        net.eval()
                        images, labels = val_data
                        images, labels = images.to(device), labels.to(device)
                        outputs = net(images)
                        # take the class with the highest score (the index in outputs.data)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum()
                    correct = float(correct.cpu().numpy().tolist())
                    accuracy = correct / total
                    print("{} Validation Accuracy = {:.4f}".format(
                        datetime.now(), accuracy))
                    if accuracy > highest_accuracy:
                        highest_accuracy = accuracy
                        #save model
                        print("{} Saving model...".format(datetime.now()))
                        torch.save(
                            net.state_dict(), '%s/global_step=%d.pth' %
                            (checkpoint_path, global_step))
                        print("{} Model saved ".format(datetime.now()))
                        #save highest accuracy
                        f = open(highest_accuracy_path, 'w')
                        f.write(str(highest_accuracy))
                        f.close()
                        #save global step
                        f = open(global_step_path, 'w')
                        f.write(str(global_step))
                        print("{} model saved at global step = {}".format(
                            datetime.now(), global_step))
                        f.close()
                        print('continue training')
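exponential_decay_learning_rate is defined elsewhere in the project; a minimal implementation consistent with the call site might be (a sketch):

def exponential_decay_learning_rate(optimizer, base_lr, global_step,
                                    decay_steps, decay_factor):
    # Hypothetical sketch: lr = base_lr * decay_factor ** (global_step / decay_steps)
    lr = base_lr * decay_factor ** (global_step / decay_steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr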
Example #17
    def __init__(self, base_0, base_1):
        super(StartOfLineFinder, self).__init__()

        self.cnn = vgg.vgg11()
        self.base_0 = base_0
        self.base_1 = base_1
Example #18
def skyline_model_provider():
    return vgg.vgg11().cuda()
Example #19
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False,
                                         num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

# Model
print('==> Building model..')
'''
Critical - CHOICE OF ARCHITECTURE 
'''
#net = resnet.ResNet18()
#net = vgg.vgg16()
net = vgg.vgg11()
#if args.BN:
#    net = vgg.vgg16_bn()

#else:
#    net = vgg.vgg11()

# net = VGG('VGG19')
# net = vgg.vgg11_bn()
# net = resnet.ResNet18()
# net = PreActResNet18()
# net = GoogLeNet()
# net = DenseNet121()
# net = ResNeXt29_2x64d()
# net = MobileNet()
# net = MobileNetV2()
Example #20
def get_model(args):
    network = args.network

    if network == 'vgg11':
        model = vgg.vgg11(num_classes=args.class_num)
    elif network == 'vgg13':
        model = vgg.vgg13(num_classes=args.class_num)
    elif network == 'vgg16':
        model = vgg.vgg16(num_classes=args.class_num)
    elif network == 'vgg19':
        model = vgg.vgg19(num_classes=args.class_num)
    elif network == 'vgg11_bn':
        model = vgg.vgg11_bn(num_classes=args.class_num)
    elif network == 'vgg13_bn':
        model = vgg.vgg13_bn(num_classes=args.class_num)
    elif network == 'vgg16_bn':
        model = vgg.vgg16_bn(num_classes=args.class_num)
    elif network == 'vgg19_bn':
        model = vgg.vgg19_bn(num_classes=args.class_num)
    elif network == 'resnet18':
        model = models.resnet18(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet34':
        model = models.resnet34(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet50':
        model = models.resnet50(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet101':
        model = models.resnet101(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'resnet152':
        model = models.resnet152(num_classes=args.class_num)
        model.conv1 = torch.nn.Conv2d(in_channels=1,
                                      out_channels=model.conv1.out_channels,
                                      kernel_size=model.conv1.kernel_size,
                                      stride=model.conv1.stride,
                                      padding=model.conv1.padding,
                                      bias=model.conv1.bias)
    elif network == 'densenet121':
        model = densenet.densenet121(num_classes=args.class_num)
    elif network == 'densenet169':
        model = densenet.densenet169(num_classes=args.class_num)
    elif network == 'densenet161':
        model = densenet.densenet161(num_classes=args.class_num)
    elif network == 'densenet201':
        model = densenet.densenet201(num_classes=args.class_num)
    else:
        raise ValueError('unknown network: {}'.format(network))

    return model
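Each ResNet branch above repeats the same conv1 replacement so the network accepts single-channel input; a helper could factor this out, as in this sketch:

def adapt_first_conv(model, in_channels=1):
    # Rebuild conv1 with a new input-channel count, keeping the other
    # hyperparameters of the original layer.
    old = model.conv1
    model.conv1 = torch.nn.Conv2d(in_channels=in_channels,
                                  out_channels=old.out_channels,
                                  kernel_size=old.kernel_size,
                                  stride=old.stride,
                                  padding=old.padding,
                                  bias=old.bias is not None)
    return model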
Example #21
                    type=str,
                    metavar='PATH',
                    help='path to save prune model (default: none)')
parser.add_argument('-j',
                    '--workers',
                    default=20,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 20)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = vgg11()
model.features = nn.DataParallel(model.features)
cudnn.benchmark = True

if args.cuda:
    model.cuda()

if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
            args.model, checkpoint['epoch'], best_prec1))
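
One caveat for snippets like the above: torch.load defaults to the devices the checkpoint was saved from. A hedged sketch of a CPU-safe variant (a defensive rewrite, not part of the original code):

# load to CPU first, then move; avoids failures on machines that lack the
# GPU layout the checkpoint was saved with
checkpoint = torch.load(args.model, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
if args.cuda:
    model.cuda()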
Example #22
def main():
    global args, best_prec1
    args = parser.parse_args()

    if (args.arch != "mobilenetv2" or args.lr_strategy
            == 'step') and len(args.lr) != len(args.decay_epoch) + 1:
        # MobileNet v2 uses cosine learning rate schedule
        print("args.lr: {}".format(args.lr))
        print("args.decay-epoch: {}".format(args.decay_epoch))
        raise ValueError("inconsistent between lr-decay-gamma and decay-epoch")

    if args.width_multiplier != 1.0 and args.arch != "mobilenetv2":
        if args.arch == "resnet50":
            print(
                "For ResNet-50 with --width-multiplier, no need to specific --width-multiplier in finetuning."
            )
        raise ValueError(
            "--width-multiplier only works for MobileNet v2. \n"
            f"got --width-multiplier {args.width_multiplier} for --arch {args.arch}"
        )

    if args.arch == "mobilenetv2" and not args.lr_strategy == 'step':
        assert len(args.lr) == 1, "For MobileNet v2, the cosine learning rate " \
                                  "schedule needs exactly one --lr value."
        print("WARNING: --decay-step is disabled.")

    if args.warmup and not args.scratch:
        raise ValueError("Finetuning should not use --warmup.")

    print(args)
    print(f"Current git hash: {common.get_git_id()}")

    args.distributed = args.world_size > 1

    # reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.rank == 0:
        if not os.path.exists(args.save):
            os.makedirs(args.save)
        if not os.path.exists(args.backup_path):
            os.makedirs(args.backup_path)

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.refine:
        checkpoint = torch.load(args.refine)

        print("loaded checkpoint. cfg:")
        print(checkpoint['cfg'])
        if args.arch == "vgg11":
            model = vgg11(config=checkpoint['cfg'])
        elif args.arch == "resnet50":
            if args.expand:
                if "downsample_cfg" in checkpoint:
                    downsample_cfg = checkpoint["downsample_cfg"]
                else:
                    downsample_cfg = None

                # there are no parameters in the ChannelMask layers, so the
                # expand indices must be restored manually
                if 'expand_idx' not in checkpoint:
                    # compatible with resprune-expand checkpoints, which only
                    # store the bn3 masks
                    expand_idx = []
                    bn3_masks = checkpoint["bn3_masks"]
                    for mask in bn3_masks:
                        idx = np.argwhere(
                            mask[1].clone().cpu().numpy()).squeeze()
                        expand_idx.append(idx)
                else:
                    # resprune-expand-gate pruning saves expand_idx directly
                    expand_idx = checkpoint['expand_idx']

                model = ResNetExpand(cfg=checkpoint['cfg'],
                                     expand_idx=expand_idx,
                                     aux_fc=False,
                                     downsample_cfg=downsample_cfg,
                                     gate=False)
            else:
                raise NotImplementedError("Use --expand option.")
        elif args.arch == "mobilenetv2":
            input_mask = checkpoint.get('gate') is True
            model = mobilenet_v2(inverted_residual_setting=checkpoint['cfg'],
                                 width_mult=args.width_multiplier,
                                 input_mask=input_mask)
        else:
            raise NotImplementedError("{} is not supported".format(args.arch))

        if not args.scratch:
            # do not load weight parameters when retrain from scratch
            model.load_state_dict(checkpoint['state_dict'])

        if 'downsample_cfg' in checkpoint:
            # set downsample expand layer
            downsample_modules = [
                (name, m) for name, m in model.named_modules()
                if isinstance(m, nn.Sequential) and 'downsample' in name
            ]
            downsample_mask = checkpoint['downsample_mask']
            assert len(downsample_modules) == len(downsample_mask)
            for i, (name, m) in enumerate(downsample_modules):
                mask = downsample_mask[f"{name}.1"]
                assert mask.shape[0] == m[-1].idx.shape[0]
                m[-1].idx = np.argwhere(mask.clone().cpu().numpy()).squeeze()

        if args.arch == "mobilenetv2":
            # restore the mask of the Expand layer
            expand_idx = checkpoint['expand_idx']
            if expand_idx is not None:
                for m_name, sub_module in model.named_modules():
                    if isinstance(sub_module, models.common.ChannelOperation):
                        sub_module.idx = expand_idx[m_name]
            else:
                print(
                    "Warning: expand_idx not set in checkpoint; using default settings."
                )

        # the masks change the content of the weight tensors: weights below
        # the threshold are zeroed out, so masking must happen before the
        # model is wrapped in DataParallel
        weight_masks = None

        if not args.distributed:
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                model.features = torch.nn.DataParallel(model.features)
                model.cuda()
            else:
                model = torch.nn.DataParallel(model).cuda()
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        raise ValueError("--refine must be specified in finetuning!")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

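    # skipping weight decay on BatchNorm affine parameters (gamma/beta) is a
    # common trick: decaying them tends to hurt accuracy while contributing
    # little regularization benefit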
    if args.no_bn_wd:
        no_wd_params = []
        for module_name, sub_module in model.named_modules():
            if isinstance(sub_module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                for param_name, param in sub_module.named_parameters():
                    no_wd_params.append(param)
                    print(
                        f"No weight decay param: module {module_name} param {param_name}"
                    )
        no_wd_params_set = set(no_wd_params)
        wd_params = []
        for param_name, model_p in model.named_parameters():
            if model_p not in no_wd_params_set:
                wd_params.append(model_p)
                print(f"Weight decay param: parameter name {param_name}")

        optimizer = torch.optim.SGD([{
            'params': list(no_wd_params),
            'weight_decay': 0.
        }, {
            'params': list(wd_params),
            'weight_decay': args.weight_decay
        }],
                                    args.lr[0],
                                    momentum=args.momentum)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr[0],
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            if isinstance(best_prec1, torch.Tensor):
                best_prec1 = best_prec1.item()

            for param_name, param in model.named_parameters():
                if param_name not in checkpoint['state_dict']:
                    raise ValueError(
                        "Missing parameter {}, do not load!".format(
                            param_name))

            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            # move optimizer buffer to GPU
            for p in optimizer.state.keys():
                param_state = optimizer.state[p]
                buf = param_state["momentum_buffer"]
                param_state["momentum_buffer"] = buf.cuda()  # move to device

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError("=> no checkpoint found at '{}'".format(
                args.resume))

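    # NOTE: this overrides the `cudnn.benchmark = False` set above for
    # reproducibility; benchmark mode trades determinism for speed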
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

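    # ImageNet-style augmentation: random resized crop, optional Lighting
    # jitter (presumably AlexNet-style PCA noise), then horizontal flip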
    training_transformers = [transforms.RandomResizedCrop(224)]
    if args.lighting:
        training_transformers.append(utils.common.Lighting(0.1))
    training_transformers += [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), normalize
    ]

    train_dataset = datasets.ImageFolder(
        traindir, transforms.Compose(training_transformers))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, epoch=0, writer=None)
        return

    # only master process write to disk
    writer = SummaryWriter(logdir=args.save, write_to_disk=args.rank == 0)

    if args.arch == "resnet50":
        summary = pruning_summary_resnet50(model, args.expand,
                                           args.width_multiplier)
    else:
        print("WARNING: arch {} do not support pretty print".format(args.arch))
        summary = str(model)

    print(model)

    print("********** MODEL SUMMARY **********")
    print(summary)
    print("********** ************* **********")

    writer.add_text("model/summary", summary)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        weights = bn_weights(model)
        for bn_name, bn_weight in weights:
            writer.add_histogram("bn/" + bn_name, bn_weight, global_step=epoch)

        # train for one epoch
        train(train_loader,
              model,
              criterion,
              optimizer,
              epoch,
              writer,
              mask=weight_masks)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch, writer)

        if args.debug:
            # make sure the prec1 is large enough to test saving functions
            prec1 = epoch

        writer.add_scalar("train/lr", optimizer.param_groups[0]['lr'], epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            args.save,
            save_backup=(epoch - args.start_epoch) % 10 == 0,
            backup_path=args.backup_path,
            epoch=epoch)

    writer.close()
    print("Best prec@1: {}".format(best_prec1))
Example #23
logger.save(str(args), 'args')

# data
dataset = SVHN(args.datadir)
logger.save(str(dataset), 'dataset')
test_list = dataset.getTestList(1000, True)

# model
start_iter = 0
lr = args.lr
if args.model == 'resnet':
    from resnet import ResNet18
    model = ResNet18().cuda()
elif args.model == 'vgg':
    from vgg import vgg11
    model = vgg11().cuda()
else:
    raise NotImplementedError()
criterion = CEwithMask
optimizer = torch.optim.SGD(model.parameters(),
                            lr=lr,
                            momentum=args.momentum,
                            weight_decay=args.weightdecay)
if args.resume:
    checkpoint = torch.load(args.resume)
    start_iter = checkpoint['iter'] + 1
    lr = checkpoint['lr']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.save("=> loaded checkpoint '{}'".format(args.resume))
logger.save(str(model), 'classifier')
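
The resume branch above expects checkpoints carrying 'iter', 'lr', 'model', and 'optimizer' keys. A sketch of the matching save call (assumed, since saving happens outside this excerpt; the iteration variable and path are illustrative):

torch.save({
    'iter': it,                       # hypothetical iteration counter
    'lr': lr,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pth')                  # hypothetical path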
Example #24
            len_in *= x
            global_model = MLP(dim_in=len_in,
                               dim_hidden=64,
                               dim_out=args.num_classes)
    elif args.model == 'cnn_vc':
        global_model = CNNCifar_fedVC(args=args)
    elif args.model == 'cnn_vcbn':
        global_model = CNNCifar_VCBN(args=args)
    elif args.model == 'cnn_vcgn':
        global_model = CNNCifar_VCGN(args=args)
    elif args.model == 'resnet18':
        global_model = resnet18()
    elif args.model == 'resnet32':
        global_model = ResNet32_test()
    elif args.model == 'vgg':
        global_model = vgg11()
    elif args.model == 'cnn_ws':
        global_model = CNNCifar_WS(args=args)

    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)

    global_model.train()
    print(global_model)
    loggertxt.info(global_model)
    # create directions for loss-landscape visualization
    #rand_directions = create_random_direction(global_model)
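
The commented-out call above points at loss-landscape visualization. A minimal sketch of what create_random_direction might do, in the filter-normalized style of Li et al. (an assumption; the actual helper is not shown in this excerpt):

import torch

def create_random_direction(model):
    # one Gaussian tensor per parameter, filter-normalized so that each
    # filter of the direction matches the norm of the corresponding filter
    direction = []
    for p in model.parameters():
        d = torch.randn_like(p)
        if p.dim() > 1:  # conv / linear weights: normalize per output filter
            for d_f, p_f in zip(d, p):
                d_f.mul_(p_f.norm() / (d_f.norm() + 1e-10))
        else:            # biases and BN parameters: match the whole-vector norm
            d.mul_(p.norm() / (d.norm() + 1e-10))
        direction.append(d)
    return direction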