Example #1
def get_model():
    # model = VGG16()
    model = AlexNet()

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adam(learning_rate=0.001),
                  metrics=['accuracy'])
    return model
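A minimal usage sketch for get_model above, assuming keras and a custom AlexNet implementation are importable, and that x_train/y_train are placeholder names for already-preprocessed arrays:

model = get_model()
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)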
Example #2
def get_model(arch, wts_path):
    if arch == 'alexnet':
        model = AlexNet()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    elif arch == 'pt_alexnet':
        model = models.alexnet()
        classif = list(model.classifier.children())[:5]
        model.classifier = nn.Sequential(*classif)
        load_weights(model, wts_path)
    elif arch == 'mobilenet':
        model = MobileNetV2()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    elif arch == 'resnet50x5_swav':
        model = resnet50w5()
        model.l2norm = None
        load_weights(model, wts_path)
    elif 'resnet' in arch:
        model = models.__dict__[arch]()
        model.fc = nn.Sequential()
        load_weights(model, wts_path)
    else:
        raise ValueError('arch not found: ' + arch)

    for p in model.parameters():
        p.requires_grad = False

    return model
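load_weights is external to this snippet; a plausible sketch, assuming wts_path points to a PyTorch checkpoint (possibly wrapped as {'state_dict': ...}) whose backbone keys match the model:

import torch

def load_weights(model, wts_path):
    ckpt = torch.load(wts_path, map_location='cpu')
    state_dict = ckpt.get('state_dict', ckpt)  # unwrap {'state_dict': ...} checkpoints
    # strict=False tolerates the emptied fc/classifier heads above
    model.load_state_dict(state_dict, strict=False)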
Example #3
def train(train_dataset, val_dataset, configs):

    train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size = configs["batch_size"],
            shuffle = True
    )

    val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size = configs["batch_size"],
            shuffle = False
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = AlexNet().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=configs["lr"])

    for epoch in range(configs["epochs"]):

        model.train()
        running_loss = 0.0
        correct = 0

        for i, (inputs, labels) in enumerate(tqdm(train_loader)):

            inputs, labels = inputs.to(device), labels.squeeze().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()

            running_loss += loss.item()

        print("[%d] loss: %.4f" %
                  (epoch + 1, running_loss / train_dataset.__len__()))

        model.eval()
        correct = 0

        with torch.no_grad():

            for i, (inputs, labels) in enumerate(tqdm(val_loader)):

                inputs, labels = inputs.to(device), labels.squeeze().to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()

        print("Accuracy of the network on the %d test images: %.4f %%" %
                (val_dataset.__len__(), 100. * correct / val_dataset.__len__()))

    torch.save(model.state_dict(), "/opt/output/model.pt")
Example #4
    def get_model(self):
        # Get modeling
        model_name = self.argument.model
        if model_name == 'AlexNet':
            from models.alexnet import AlexNet
            return AlexNet()

        elif model_name == 'vgg11':
            from models.vgg import VGG
            return VGG(vgg_type='VGG11')

        elif model_name == 'vgg13':
            from models.vgg import VGG
            return VGG(vgg_type='VGG13')

        elif model_name == 'vgg16':
            from models.vgg import VGG
            return VGG(vgg_type='VGG16')

        elif model_name == 'vgg19':
            from models.vgg import VGG
            return VGG(vgg_type='VGG19')

        elif model_name == 'GoogleNet_inception_v1':
            from models.googlenet import GoogleNetInceptionV1
            return GoogleNetInceptionV1()

        elif model_name == 'GoogleNet_inception_v2':
            from models.googlenet import GoogleNetInceptionV2
            return GoogleNetInceptionV2()

        else:
            raise NotImplementedError(model_name)
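The if/elif ladder above can also be written as a name-to-constructor table; a sketch, assuming the same models package as the snippet:

    def get_model(self):
        from models.alexnet import AlexNet
        from models.vgg import VGG
        from models.googlenet import GoogleNetInceptionV1, GoogleNetInceptionV2

        factories = {
            'AlexNet': AlexNet,
            'vgg11': lambda: VGG(vgg_type='VGG11'),
            'vgg13': lambda: VGG(vgg_type='VGG13'),
            'vgg16': lambda: VGG(vgg_type='VGG16'),
            'vgg19': lambda: VGG(vgg_type='VGG19'),
            'GoogleNet_inception_v1': GoogleNetInceptionV1,
            'GoogleNet_inception_v2': GoogleNetInceptionV2,
        }
        model_name = self.argument.model
        if model_name not in factories:
            raise NotImplementedError(model_name)
        return factories[model_name]()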
Example #5
def load_model(model_name):
    if model_name == 'alexnet':
        from models.alexnet import AlexNet
        model = AlexNet()
        return model

    else:
        print("The model_name is not exists.")
def train(args):
	base_dir = args.base_dir
	num_classes = args.num_classes
	model_name = str.lower(args.model)

	if model_name == 'lenet_baseline':
		input_shape = (32, 32, 3)
		model = LeNet_baseline(input_shape=input_shape, num_classes=num_classes)
	elif model_name == 'lenet_modified':
		input_shape = (32, 32, 3)
		model = LeNet_modified(input_shape=input_shape, num_classes=num_classes)
	elif model_name == 'alexnet':
		input_shape = (227, 227, 3)
		model = AlexNet(input_shape=input_shape, num_classes=num_classes)
	elif model_name == 'vgg16':
		input_shape = (224, 224, 3)
		model = VGG16(input_shape=input_shape, num_classes=num_classes)
	elif model_name == 'vgg19':
		input_shape = (224, 224, 3)
		model = VGG19(input_shape=input_shape, num_classes=num_classes)
	elif model_name == 'resnet18':
		input_shape = (112, 112, 3)
		model = ResNet18(input_shape=input_shape, num_classes=num_classes)
	else:
		print('Please choose an implemented model!')
		sys.exit()

	# get training dataset
	x, y = input_preprocess(base_dir, input_shape, num_classes)
	# split dataset into train and val set
	x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2, random_state=seed)
	print('Train images loaded')

	# one-hot encoding for each label
	y_train = one_hot(y_train, num_classes)
	y_val = one_hot(y_val, num_classes)

	# train the model
	callbacks = get_callbacks(args.checkpoint, model_name)
	model.compile(optimizer=Adam(lr=args.lr),
				  loss='categorical_crossentropy',
				  metrics=['acc'])
	model.summary()
	history = model.fit(x_train, y_train,
			  batch_size=args.batch_size,
			  validation_data=(x_val, y_val),
			  epochs=args.epochs,
			  callbacks=callbacks)
	visualize(history)
	return model
Example #7
def init_model(args, num_train_pids):

    print("Initializing model: {}".format(args.arch))
    if args.arch.lower() =='resnet50':
        model = ResNet50TP(num_classes=num_train_pids)
    elif args.arch.lower() =='alexnet':
        model = AlexNet(num_classes=num_train_pids)
    else:
        raise ValueError('unknown model ' + args.arch)  # assert can be stripped with -O

    # pretrained model loading
    if args.pretrained_model is not None:
        model = load_pretrained_model(model, args.pretrained_model)
    
    return model
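load_pretrained_model is not defined here; a common shape-filtering variant, sketched as an assumption (it keeps only checkpoint entries whose names and shapes match the freshly built model, so a classifier sized for a different num_train_pids is skipped):

import torch

def load_pretrained_model(model, ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    model_dict = model.state_dict()
    # keep only entries present in the model with identical shapes
    compatible = {k: v for k, v in state_dict.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return model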
Example #8
 def __init__(self, pretrained_dir=os.path.join(CURR_DIR, "models",
                                                "init")):
     super(GONET, self).__init__()
     #alexnet = models.alexnet()
     #alexnet.load_state_dict(torch.load(osp.join(CURR_DIR, "alexnet.pth")))
     #self.features = alexnet.features
     from models.alexnet import AlexNet
     self.features = AlexNet()
     self.regressor = nn.Sequential(nn.Linear(256 * 6 * 6 * 2,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5),
                                    nn.Linear(4096,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5),
                                    nn.Linear(4096,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5), nn.Linear(4096, 4))
     self.weight_init(pretrained_dir)
Example #9
def test(args):
    base_dir = args.base_dir
    num_classes = args.num_classes
    model_name = str.lower(args.model)

    if model_name == 'lenet_baseline':
        input_shape = (32, 32, 3)
        model = LeNet_baseline(input_shape=input_shape,
                               num_classes=num_classes)
    elif model_name == 'lenet_modified':
        input_shape = (32, 32, 3)
        model = LeNet_modified(input_shape=input_shape,
                               num_classes=num_classes)
    elif model_name == 'alexnet':
        input_shape = (227, 227, 3)
        model = AlexNet(input_shape=input_shape, num_classes=num_classes)
    elif model_name == 'vgg16':
        input_shape = (224, 224, 3)
        model = VGG16(input_shape=input_shape, num_classes=num_classes)
    elif model_name == 'vgg19':
        input_shape = (224, 224, 3)
        model = VGG19(input_shape=input_shape, num_classes=num_classes)
    elif model_name == 'resnet18':
        input_shape = (112, 112, 3)
        model = ResNet18(input_shape=input_shape, num_classes=num_classes)
    else:
        print('Please choose an implemented model!')
        sys.exit()

    # load test set
    x, y = load_test(base_dir, input_shape, num_classes)
    print('Test images loaded')

    model.load_weights(args.pretrained_model)
    # predict_classes was removed in newer Keras; argmax over predict is equivalent
    pred = np.argmax(model.predict(x), axis=1)
    accuracy = accuracy_score(y, pred)
    print('Accuracy: {}'.format(accuracy))
    print(classification_report(y, pred))
Example #10
def main():
    dataset = Cifar10()
    # dataset = Cifar100()
    # dataset = Mnist()

    model = AlexNet()
    # model = VGG()
    # model = GoogLeNet()
    # model = ResNet()

    # training
    trainer = ClfTrainer(model, dataset)
    trainer.run_training(epochs, batch_size, learning_rate, './test-ckpt')
    #trainer.run_training(epochs, batch_size, learning_rate, './test-ckpt', options={'model_type': ... })

    # resuming training
    trainer.resume_training_from_ckpt(epochs, batch_size, learning_rate, './test-ckpt', './new-test-ckpt')
    #trainer.resume_training_from_ckpt(epochs, batch_size, learning_rate, './test-ckpt', './new-test-ckpt', options={'model_type': ... })

    # transfer learning
    new_dataset = Cifar100()
    trainer = ClfTrainer(model, new_dataset)
    trainer.run_transfer_learning(epochs, batch_size, learning_rate, './new-test-ckpt-1', './test-transfer-learning-ckpt')
Example #11
def Model(args):

    # TODO: Fix args.pretrained
    if args.model == "softmax":
        model = Softmax(args.image_size, args.no_of_classes)
    elif args.model == "twolayernn":
        model = TwoLayerNN(args.image_size, args.no_of_classes)
    elif args.model == "threelayernn":
        model = ThreeLayerNN(args.image_size, args.no_of_classes)
    elif args.model == "onelayercnn":
        model = OneLayerCNN(args.image_size, args.no_of_classes)
    elif args.model == "twolayercnn":
        model = TwoLayerCNN(args.image_size, args.no_of_classes)
    elif args.model == "vggnet":
        model = VGGNet(args.image_size, args.no_of_classes)
    elif args.model == "alexnet":
        model = AlexNet(args.image_size, args.no_of_classes)
    elif args.model == "resnet":
        model = ResNet18()
        # self.model = models.resnet18(pretrained=True)
    else:
        raise Exception("Unknown model {}".format(args.model))

    return model
Example #12
 def __init__(self, init_dir=os.path.join(CURR_DIR, "models", "init")):
     super(GONET, self).__init__()
     #alexnet = models.alexnet()
     #alexnet.load_state_dict(torch.load(osp.join(CURR_DIR, "alexnet.pth")))
     #self.features = alexnet.features
     from models.alexnet import AlexNet
     self.features = AlexNet()
     self.corr = Correlation(pad_size=3,
                             kernel_size=1,
                             max_displacement=3,
                             stride1=1,
                             stride2=2,
                             corr_multiply=1)
     self.conv_redir = nn.Conv2d(256, 256, kernel_size=1)
     self.regressor = nn.Sequential(nn.Linear(265 * 6 * 6,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5),
                                    nn.Linear(4096,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5),
                                    nn.Linear(4096,
                                              4096), nn.ReLU(inplace=True),
                                    nn.Dropout(p=0.5), nn.Linear(4096, 4))
     self.weight_init(init_dir)
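The 265 input features of the first nn.Linear are deliberate, assuming a FlowNet-style Correlation layer: max_displacement=3 with stride2=2 yields a (2*(3//2)+1)**2 = 9-channel cost volume, which is concatenated with the 256 channels from conv_redir:

max_displacement, stride2 = 3, 2
corr_channels = (2 * (max_displacement // stride2) + 1) ** 2  # 3**2 = 9
redir_channels = 256                                          # conv_redir output
print(corr_channels + redir_channels)                         # 265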
Example #13
# Get data loaders for the training and test sets
# train_loader, valid_loader = cfg.dataset_loader(root=cfg.cat_dog_train, train=True,
#                                                 data_preprocess=[train_data_preprocess, valid_data_preprocess],
#                                                 valid_coef=0.1)

train_loader = cfg.dataset_loader(root=cfg.cat_dog_train, train=True,
                                  data_preprocess=train_data_preprocess)
valid_loader = cfg.dataset_loader(root=cfg.cat_dog_valid, train=True,
                                  data_preprocess=valid_data_preprocess)
# test_loader = cfg.dataset_loader(root=cfg.cat_dog_test, train=False, shuffle=False,
#                                  data_preprocess=valid_data_preprocess)

# --------------- Build the network, define the loss function and the optimizer ---------------
# Build the network
# net = resnet()
net = AlexNet(num_classes=cfg.num_classes)
# net = resnet50()
#net = resnet18()
# Rewrite the network's final layer
#fc_in_features = net.fc.in_features  # input features of the network's final layer
#net.fc = nn.Linear(in_features=fc_in_features, out_features=cfg.num_classes)

# Move the network and loss function to the GPU; configure the optimizer
net = net.to(cfg.device)
# net = nn.DataParallel(net, device_ids=[0, 1])
# criterion=nn.BCELoss()
#criterion = nn.BCEWithLogitsLoss().cuda(device=cfg.device)
criterion = nn.CrossEntropyLoss().cuda(device=cfg.device)
# Standard optimizers: SGD and Adam
#optimizer = optim.SGD(params=net.parameters(), lr=cfg.learning_rate,
#                      weight_decay=cfg.weight_decay, momentum=cfg.momentum)
Example #14
def main():
    global best_acc
    start_epoch = args.start_epoch  # start from epoch 0 or last checkpoint epoch

    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    print_training_params(args=args, txt_file_path=txt_file_path)

    # Data
    print(f'==> Preparing dataset {args.dataset}')
    if args.dataset in ['cifar10', 'cifar100']:
        depth = 28
        widen_factor = 10
        dropout = 0.3
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor()
        ])

    elif args.dataset == 'tiny-imagenet':
        transform_train = transforms.Compose([
            transforms.ToTensor()
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor()
        ])

    elif args.dataset == 'imagenet':
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])

    elif args.dataset == 'mnist':
        transform_train = transforms.Compose([
            transforms.ToTensor()
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor()
        ])

    elif args.dataset == 'SVHN':
        depth = 16
        widen_factor = 4
        dropout = 0.4
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor()
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor()
        ])

    print(f'Running on dataset {args.dataset}')
    if args.dataset in ['cifar10', 'cifar100', 'mnist']:
        if args.dataset == 'cifar10':
            dataloader = datasets.CIFAR10
            num_classes = 10
        elif args.dataset == 'cifar100':
            dataloader = datasets.CIFAR100
            num_classes = 100
        elif args.dataset == 'mnist':
            dataloader = datasets.MNIST
            num_classes = 10

        trainset = dataloader(root='.data', train=True, download=True, transform=transform_train)
        testset = dataloader(root='.data', train=False, download=False, transform=transform_test)

    elif args.dataset == 'imagenet':
        trainset = datasets.ImageFolder('imagenet/train', transform=transform_train)
        testset = datasets.ImageFolder('imagenet/val', transform=transform_test)
        num_classes = 1000

    elif args.dataset == 'SVHN':
        trainset = datasets.SVHN('data', split='train', transform=transform_train, download=True)
        testset = datasets.SVHN('data', split='test', transform=transform_test, download=True)
        num_classes = 10
    
    trainloader = data.DataLoader(trainset, batch_size=args.train_batch, shuffle=True, num_workers=args.workers)
    testloader = data.DataLoader(testset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

    # Model
    print("==> creating model '{}'".format(args.arch))
    if args.arch == 'vgg16':
        model = VGG16(
            dataset=args.dataset,
            num_classes=num_classes,
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta,
            finetune=args.finetune
        )

    elif args.arch == 'resnet18':
        model = ResNet18(
            dataset=args.dataset,
            num_classes=num_classes,
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta,
            finetune=args.finetune
        )

    elif args.arch == 'madry':
        model = MadryNet(
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta
        )

    elif args.arch == 'lenet':
        model = LeNet(
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta
        )

    elif args.arch == 'alexnet':
        model = AlexNet(
            dataset=args.dataset,
            num_classes=num_classes,
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta
        )

    elif args.arch == 'wide-resnet':
        model = Wide_ResNet(
            dataset=args.dataset,
            num_classes=num_classes,
            kernels1=args.kernels1,
            kernels2=args.kernels2,
            kernels3=args.kernels3,
            orientations=args.orientations,
            learn_theta=args.learn_theta,
            finetune=args.finetune,
            depth=depth,
            widen_factor=widen_factor,
            dropout_rate=dropout,
            use_7x7=args.use_7x7
        )

    print('Model:')
    print(model)
    print_to_log(text=repr(model), txt_file_path=txt_file_path)
    
    if device == 'cuda':
        model = torch.nn.DataParallel(model).to(device)
    
    # Compute number of parameters and print them
    param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_txt = f'    Total trainable params: {param_num:d}'
    print_to_log(text=param_txt, txt_file_path=txt_file_path)
    print(param_txt)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # Resume
    title = f'{args.dataset}-' + args.arch
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint...')
        assert osp.isfile(args.resume), 'Error: no checkpoint directory found!'
        args.checkpoint = osp.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])


    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(testloader, model, criterion, start_epoch, device)
        print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
        return

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_acc = train(
            trainloader, model, criterion, optimizer, epoch, device, train_adv=args.train_adv, args=args)
        test_loss, test_acc = test(
            testloader, model, criterion, epoch, device)

        # append logger file
        logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])

        # save model
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer' : optimizer.state_dict(),
            }, is_best, checkpoint=args.checkpoint)

        if args.kernels1 is not None:
            plot_kernels(model, args.checkpoint, epoch, device)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))

    print('Best acc:')
    print(best_acc)

    print('Training finished. Running attack')
    main_attack(args)

    print('Running SVD computation')
    main_svs_computation(args)
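adjust_learning_rate and the global state dict come from the surrounding training harness; a typical step-decay version, written here as an assumed sketch (args.schedule is a hypothetical list of decay epochs):

def adjust_learning_rate(optimizer, epoch):
    if epoch in args.schedule:
        state['lr'] *= 0.1  # decay by 10x at each scheduled epoch
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']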
Example #15
def main_attack(args):
    # Use CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    # Random seed
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.seed)

    # Data
    print(f'==> Preparing dataset {args.dataset}')
    if args.dataset in ['cifar10', 'cifar100']:
        depth = 28
        widen_factor = 10
        dropout = 0.3
        transform_test = transforms.Compose([transforms.ToTensor()])

    elif args.dataset == 'tiny-imagenet':
        transform_test = transforms.Compose([transforms.ToTensor()])

    elif args.dataset == 'imagenet':
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
    elif args.dataset == 'mnist':
        transform_test = transforms.Compose([transforms.ToTensor()])
    elif args.dataset == 'SVHN':
        depth = 16
        widen_factor = 4
        dropout = 0.4
        transform_train = transforms.Compose(
            [transforms.RandomCrop(32, padding=4),
             transforms.ToTensor()])
        transform_test = transforms.Compose([transforms.ToTensor()])

    print(f'Running on dataset {args.dataset}')
    if args.dataset in ['cifar10', 'cifar100', 'mnist']:
        if args.dataset == 'cifar10':
            dataloader = datasets.CIFAR10
            num_classes = 10
        elif args.dataset == 'cifar100':
            dataloader = datasets.CIFAR100
            num_classes = 100
        elif args.dataset == 'mnist':
            dataloader = datasets.MNIST
            num_classes = 10

        testset = dataloader(root='.data',
                             train=False,
                             download=False,
                             transform=transform_test)

    elif args.dataset == 'tiny-imagenet':
        testset = datasets.ImageFolder('tiny-imagenet-200/val',
                                       transform=transform_test)
        num_classes = 200

    elif args.dataset == 'imagenet':
        testset = datasets.ImageFolder('imagenet/val',
                                       transform=transform_test)
        num_classes = 1000

    elif args.dataset == 'SVHN':
        testset = datasets.SVHN('data',
                                split='test',
                                transform=transform_test,
                                download=True)
        num_classes = 10

    testloader = data.DataLoader(testset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.workers)
    # Model
    if args.arch == 'vgg16':
        model = VGG16(dataset=args.dataset,
                      num_classes=num_classes,
                      kernels1=args.kernels1,
                      kernels2=args.kernels2,
                      kernels3=args.kernels3,
                      orientations=args.orientations,
                      learn_theta=args.learn_theta)
    elif args.arch == 'resnet18':
        model = ResNet18(dataset=args.dataset,
                         num_classes=num_classes,
                         kernels1=args.kernels1,
                         kernels2=args.kernels2,
                         kernels3=args.kernels3,
                         orientations=args.orientations,
                         learn_theta=args.learn_theta)
    elif args.arch == 'madry':
        model = MadryNet(kernels1=args.kernels1,
                         kernels2=args.kernels2,
                         kernels3=args.kernels3,
                         orientations=args.orientations,
                         learn_theta=args.learn_theta)
    elif args.arch == 'lenet':
        model = LeNet(kernels1=args.kernels1,
                      kernels2=args.kernels2,
                      kernels3=args.kernels3,
                      orientations=args.orientations,
                      learn_theta=args.learn_theta)
    elif args.arch == 'alexnet':
        model = AlexNet(dataset=args.dataset,
                        num_classes=num_classes,
                        kernels1=args.kernels1,
                        kernels2=args.kernels2,
                        kernels3=args.kernels3,
                        orientations=args.orientations,
                        learn_theta=args.learn_theta)
    elif args.arch == 'wide-resnet':
        model = Wide_ResNet(dataset=args.dataset,
                            num_classes=num_classes,
                            kernels1=args.kernels1,
                            kernels2=args.kernels2,
                            kernels3=args.kernels3,
                            orientations=args.orientations,
                            learn_theta=args.learn_theta,
                            finetune=False,
                            depth=depth,
                            widen_factor=widen_factor,
                            dropout_rate=dropout,
                            use_7x7=args.use_7x7)

    print('Model:')
    print(model)

    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()

    # Compute number of parameters and print them
    param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_txt = f'    Total trainable params: {param_num:d}'
    print(param_txt)

    criterion = nn.CrossEntropyLoss()
    # Resume
    # Load checkpoint.
    print('==> Resuming from checkpoint...')
    checkpoint_filename = osp.join(args.checkpoint, 'model_best.pth.tar')
    assert osp.isfile(
        checkpoint_filename), 'Error: no checkpoint directory found!'
    checkpoint = torch.load(checkpoint_filename)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])

    print('\nEvaluation only')
    test_loss, test_acc = test(testloader, model, criterion, start_epoch,
                               use_cuda)
    print(' Test Loss:  %.8f, Test Acc:  %.2f' % (test_loss, test_acc))
    print(f'Running {args.attack} attack!')

    if args.attack == 'cw':
        c_vals = torch.logspace(start=-2, end=2, steps=9)
        for c in c_vals:
            print(f'Running attack with c = {c:5.3f}')
            attack_cw(model, testloader, device=device, c=c)
            print('\n')
    else:
        if args.dataset == 'mnist':
            epsilons = [.1, .2, .3, .4]
        else:
            epsilons = [2 / 255, 8 / 255, 16 / 255, .1]
        print(f'Epsilons are: {epsilons}')
        minimum = 0.
        maximum = 1.
        print(f'Images maxima: {maximum} -- minima: {minimum}')
        df = {
            'epsilons': [
                0.,
            ],
            'test_set_accs': [
                test_acc,
            ],
            'flip_rates': [
                0.,
            ],
        }
        for eps in epsilons:
            print(f'Running attack with epsilon = {eps:5.3f}')
            acc_test_set, flip_rate = attack_pgd(model,
                                                 testloader,
                                                 device=device,
                                                 minimum=minimum,
                                                 maximum=maximum,
                                                 eps=eps)
            df['epsilons'].append(eps)
            df['test_set_accs'].append(acc_test_set)
            df['flip_rates'].append(flip_rate)
            print('\n')
        df = pd.DataFrame.from_dict(df)
        print('Overall results: \n', df)
        filename = osp.join(args.checkpoint, 'attack_results.csv')
        df.to_csv(filename, index=False)
Example #16
    if not os.path.exists("logs"):
        os.makedirs('logs')

    i = 0  # find the first unused log index
    while os.path.exists("logs/log%s.txt" % i):
        i += 1

    # Initialize log path
    LOG_PATH = "logs/log%s.txt" % i

    # Shadow the built-in print so every message is appended to the log file
    def print(msg):
        with open(LOG_PATH, 'a') as f:
            f.write(f'{time.ctime()}: {msg}\n')

    # Get the configuration
    cfg = Configuration()
    net = AlexNet(cfg, training=True)

    # If this is a resumed task, set this to True
    resume = False

    # Path for train dataset
    path = 'cifar-10-batches-py/data_batch_'

    # Get the data set using data.py
    trainingset = ImageNetDataset(cfg, 'train', path)

    # Path for validation dataset
    path_val = 'cifar-10-batches-py/data_batch_1'
    valset = ImageNetDataset(cfg, 'val', path_val)

    # Make the Checkpoint path
Example #17
def main():
    parser = argparse.ArgumentParser(description='Train an image classifier.')
    parser.add_argument('--file_path',
                        type=str,
                        default="image_data",
                        help='train data path')
    parser.add_argument('--reshape_size',
                        type=int,
                        default=227,
                        help='reshape size')
    parser.add_argument('--model',
                        type=str,
                        default="alexnet",
                        help='alexnet or vgg11 or deepsimplenet')
    args = parser.parse_args()
    print("Train Start")

    # root_dir = "dataset_nakamoto_inf"
    train_dir = args.file_path

    # Specify the categories

    categories = [
        "gomoku_miso", "gomoku_moko", "hiyasi_gomoku", "hiyasi_miso",
        "hokkyoku", "hokkyoku_yasai", "miso", "miso_ran", "moko", "moko_ran",
        "sio"
    ]

    t1 = time.time()

    # Array for image data
    X = []
    # Array for label data
    Y = []
    # Array holding all the data
    allfiles = []

    # For each category and its corresponding index, gather every file into allfiles
    for idx, item in enumerate(categories):
        image_dir = train_dir + "/" + item
        files = glob.glob(image_dir + "/*")
        # Change to match the file type
        # files = glob.glob(image_dir + "/*.png")
        #     files = glob.glob(image_dir + "/*.jpg")
        # files = glob.glob(image_dir + "/*.jpeg")
        for f in files:
            allfiles.append((idx, f))

    # After shuffling, split into training and test data
    random.shuffle(allfiles)
    th = math.floor(len(allfiles) * 0.8)
    train = allfiles[0:th]
    test = allfiles[th:]
    X_train, y_train = make_sample(train, args.reshape_size)
    X_test, y_test = make_sample(test, args.reshape_size)

    X_train = X_train.astype(np.float32)
    X_train /= 255.0

    X_test = X_test.astype(np.float32)
    X_test /= 255.0

    # Split the remainder into test and validation sets
    th = math.floor(len(y_test) * 0.5)
    X_val, y_val = X_test[:th], y_test[:th]
    X_test, y_test = X_test[th:], y_test[th:]

    x_train, t_train = X_train.transpose(0, 3, 1, 2), y_train
    x_val, t_val = X_val.transpose(0, 3, 1, 2), y_val
    x_test, t_test = X_test.transpose(0, 3, 1, 2), y_test

    max_epochs = 10
    """
    if config.GPU:
        x_train, t_train = to_gpu(x_train), to_gpu(t_train)
        x_test, t_test = to_gpu(x_test), to_gpu(t_test)
    """

    if args.model == "alexnet":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))
    elif args.model == "vgg11":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))
    elif args.model == "deepsimplenet":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))

    # Start training
    trainer = Trainer(network,
                      x_train,
                      t_train,
                      x_val,
                      t_val,
                      x_test,
                      t_test,
                      epochs=max_epochs,
                      mini_batch_size=10,
                      optimizer='Adam',
                      optimizer_param={'lr': 0.0001},
                      evaluate_sample_num_per_epoch=50)

    trainer.train()

    # Save the parameters
    # {name}_{network name}.pkl
    network.save_params("results/test" + args.model + ".pkl")
    print("Saved Network Parameters!")

    t2 = time.time()
    elapsed_time = t2 - t1
    print(f"time:{elapsed_time}")

    # Plot the training results

    markers = {'train': 'o', 'test': 's'}
    x = np.arange(max_epochs)

    ### Ideally split this into multiple plots
    plt.plot(x, trainer.train_acc_list, marker='o', label='train', markevery=2)
    plt.plot(x,
             trainer.val_acc_list,
             marker='s',
             label='validation',
             markevery=2)
    plt.xlabel("epochs")
    plt.ylabel("accuracy")
    plt.ylim(0, 1.0)
    plt.legend(loc='lower right')
    plt.title("Training and validation accuracy")
    # plt.show()
    plt.savefig("results/accuracy.png")
Example #18
def test_alexnet():
    data = Tensor(np.ones([32, 3, 227, 227]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = AlexNet()
    train(net, data, label)
Example #19
def extract_feature(args):
    """Extract and save features for train split, several clips per video."""
    torch.backends.cudnn.benchmark = True
    # Force PyTorch to create its context on the specified device
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    ########### model ##############
    model = AlexNet(with_classifier=False, return_conv=True).to(device)

    if args.ckpt:
        pretrained_weights = load_pretrained_weights(args.ckpt)
        model.load_state_dict(pretrained_weights, strict=True)
    model.eval()
    torch.set_grad_enabled(False)
    ### Extract for train split ###
    train_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    train_dataset = UCF101FrameRetrievalDataset('data/ucf101', 10, True, train_transforms)
    train_dataloader = DataLoader(train_dataset, batch_size=args.bs, shuffle=False,
                                    num_workers=args.workers, pin_memory=True, drop_last=True)
    
    features = []
    classes = []
    for data in tqdm(train_dataloader):
        sampled_clips, idxs = data
        clips = sampled_clips.reshape((-1, 3, 224, 224))
        inputs = clips.to(device)
        # forward
        outputs = model(inputs)
        # print(outputs.shape)
        # exit()
        features.append(outputs.cpu().numpy().tolist())
        classes.append(idxs.cpu().numpy().tolist())

    features = np.array(features).reshape(-1, 10, outputs.shape[1])
    classes = np.array(classes).reshape(-1, 10)
    np.save(os.path.join(args.feature_dir, 'train_feature.npy'), features)
    np.save(os.path.join(args.feature_dir, 'train_class.npy'), classes)

    ### Extract for test split ###
    test_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])
    test_dataset = UCF101FrameRetrievalDataset('data/ucf101', 10, False, test_transforms)
    test_dataloader = DataLoader(test_dataset, batch_size=args.bs, shuffle=False,
                                    num_workers=args.workers, pin_memory=True, drop_last=True)

    features = []
    classes = []
    for data in tqdm(test_dataloader):
        sampled_clips, idxs = data
        clips = sampled_clips.reshape((-1, 3, 224, 224))
        inputs = clips.to(device)
        # forward
        outputs = model(inputs)
        features.append(outputs.cpu().numpy().tolist())
        classes.append(idxs.cpu().numpy().tolist())

    features = np.array(features).reshape(-1, 10, outputs.shape[1])
    classes = np.array(classes).reshape(-1, 10)
    np.save(os.path.join(args.feature_dir, 'test_feature.npy'), features)
    np.save(os.path.join(args.feature_dir, 'test_class.npy'), classes)
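The saved arrays can then drive clip retrieval; a small sketch that averages the 10 clips per video and scores nearest neighbours by cosine similarity (the metric is an assumption, the file layout follows the code above):

import numpy as np

train_feat = np.load('train_feature.npy').mean(axis=1)  # (n_videos, C)
test_feat = np.load('test_feature.npy').mean(axis=1)
train_cls = np.load('train_class.npy')[:, 0]
test_cls = np.load('test_class.npy')[:, 0]

# L2-normalise, then take the best-matching training video per test video
train_feat /= np.linalg.norm(train_feat, axis=1, keepdims=True)
test_feat /= np.linalg.norm(test_feat, axis=1, keepdims=True)
nn_idx = (test_feat @ train_feat.T).argmax(axis=1)
print('top-1 retrieval accuracy:', (train_cls[nn_idx] == test_cls).mean())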
Example #20
            logit = model(input_data_tensor.permute([0, 3, 1, 2]))
            pred_cls = torch.argmax(logit, -1)

            P += (pred_cls == input_label_tensor).sum().cpu().detach().numpy()
            N += HyperParams["batch_size"]
        if idx % 500 == 499:
            print("|acc:%f|use time:%s|" %
                  (float(P / N), str(time.time() - start_time)))
            start_time = time.time()

            # print('')


if __name__ == '__main__':
    train_data = mnist.MNIST("./mnist_data")
    model = AlexNet(10)
    if HyperParams["cuda"]:
        model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.004)
    lr_sch = torch.optim.lr_scheduler.MultiStepLR(optimizer, [1, 2, 3, 4], 0.1)
    criterion = torch.nn.CrossEntropyLoss()
    static_params = torch.load("./%s_E%d.snap" %
                               (HyperParams["model_save_prefix"], 4))
    model.load_state_dict(static_params)
    # trainval(model, optimizer, lr_sch, criterion, train_data)
    if HyperParams["quantize"]:
        model = torch.quantization.quantize_dynamic(model)
    torch.save(model.state_dict(), "./quantize_mode.snap")
Example #21
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--dataset', default='imagenet', type=str)
    parser.add_argument('--lr', default=0.0012, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--gpus', default='0,1,2,3', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--max_epoch', default=30, type=int)
    parser.add_argument('--lr_decay_steps', default='15,20,25', type=str)
    parser.add_argument('--exp', default='', type=str)
    parser.add_argument('--list', default='', type=str)
    parser.add_argument('--resume_path', default='', type=str)
    parser.add_argument('--pretrain_path', default='', type=str)
    parser.add_argument('--n_workers', default=32, type=int)

    parser.add_argument('--network', default='resnet50', type=str)

    global args
    args = parser.parse_args()

    if not os.path.exists(args.exp):
        os.makedirs(args.exp)
    if not os.path.exists(os.path.join(args.exp, 'runs')):
        os.makedirs(os.path.join(args.exp, 'runs'))
    if not os.path.exists(os.path.join(args.exp, 'models')):
        os.makedirs(os.path.join(args.exp, 'models'))
    if not os.path.exists(os.path.join(args.exp, 'logs')):
        os.makedirs(os.path.join(args.exp, 'logs'))

    # logger initialize
    logger = getLogger(args.exp)

    device_ids = [int(x) for x in args.gpus.split(',')]
    device = torch.device('cuda:0')  # 'cuda: 0' (with a space) is not a valid device string

    train_loader, val_loader = cifar.get_semi_dataloader(
        args) if args.dataset.startswith(
            'cifar') else imagenet.get_semi_dataloader(args)

    # create model
    if args.network == 'alexnet':
        network = AlexNet(128)
    elif args.network == 'alexnet_cifar':
        network = AlexNet_cifar(128)
    elif args.network == 'resnet18_cifar':
        network = ResNet18_cifar()
    elif args.network == 'resnet50_cifar':
        network = ResNet50_cifar()
    elif args.network == 'wide_resnet28':
        network = WideResNet(28, 10 if args.dataset == 'cifar10' else 100, 2)
    elif args.network == 'resnet18':
        network = resnet18()
    elif args.network == 'resnet50':
        network = resnet50()
    network = nn.DataParallel(network, device_ids=device_ids)
    network.to(device)

    classifier = nn.Linear(2048, 1000).to(device)
    # create optimizer

    parameters = network.parameters()
    optimizer = torch.optim.SGD(
        parameters,
        lr=args.lr,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cls_optimizer = torch.optim.SGD(
        classifier.parameters(),
        lr=args.lr * 50,
        momentum=0.9,
        weight_decay=args.weight_decay,
    )

    cudnn.benchmark = True

    # create memory_bank
    global writer
    writer = SummaryWriter(comment='SemiSupervised',
                           logdir=os.path.join(args.exp, 'runs'))

    # create criterion
    criterion = nn.CrossEntropyLoss()

    logging.info(beautify(args))
    start_epoch = 0
    if args.pretrain_path != '' and args.pretrain_path != 'none':
        logging.info('loading pretrained file from {}'.format(
            args.pretrain_path))
        checkpoint = torch.load(args.pretrain_path)
        state_dict = checkpoint['state_dict']
        valid_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in network.state_dict() and 'fc.' not in k
        }
        for k, v in network.state_dict().items():
            if k not in valid_state_dict:
                logging.info('{}: Random Init'.format(k))
                valid_state_dict[k] = v
        # logging.info(valid_state_dict.keys())
        network.load_state_dict(valid_state_dict)
    else:
        logging.info('Training SemiSupervised Learning From Scratch')

    logging.info('start training')
    best_acc = 0.0
    try:
        for i_epoch in range(start_epoch, args.max_epoch):
            train(i_epoch, network, classifier, criterion, optimizer,
                  cls_optimizer, train_loader, device)

            checkpoint = {
                'epoch': i_epoch + 1,
                'state_dict': network.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint,
                       os.path.join(args.exp, 'models', 'checkpoint.pth'))
            adjust_learning_rate(args.lr_decay_steps, optimizer, i_epoch)
            if i_epoch % 2 == 0:
                acc1, acc5 = validate(i_epoch, network, classifier, val_loader,
                                      device)
                if acc1 >= best_acc:
                    best_acc = acc1
                    torch.save(checkpoint,
                               os.path.join(args.exp, 'models', 'best.pth'))
                writer.add_scalar('acc1', acc1, i_epoch + 1)
                writer.add_scalar('acc5', acc5, i_epoch + 1)

            if i_epoch in [30, 60, 120, 160, 200]:
                torch.save(
                    checkpoint,
                    os.path.join(args.exp, 'models',
                                 '{}.pth'.format(i_epoch + 1)))

            logging.info(
                colorful('[Epoch: {}] val acc: {:.4f}/{:.4f}'.format(
                    i_epoch, acc1, acc5)))
            logging.info(
                colorful('[Epoch: {}] best acc: {:.4f}'.format(
                    i_epoch, best_acc)))

            with torch.no_grad():
                for name, param in network.named_parameters():
                    if 'bn' not in name:
                        writer.add_histogram(name, param, i_epoch)

            # cluster
    except KeyboardInterrupt as e:
        logging.info('KeyboardInterrupt at {} Epochs'.format(i_epoch))
        exit()
Example #22
def get_network(name: str, num_classes: int) -> 'Optional[nn.Module]':
    # Returns the matching network, or None when the name is unknown
    return \
        AlexNet(
            num_classes=num_classes) if name == 'AlexNet' else\
        DenseNet201(
            num_classes=num_classes) if name == 'DenseNet201' else\
        DenseNet169(
            num_classes=num_classes) if name == 'DenseNet169' else\
        DenseNet161(
            num_classes=num_classes) if name == 'DenseNet161' else\
        DenseNet121(
            num_classes=num_classes) if name == 'DenseNet121' else\
        DenseNet121CIFAR(
            num_classes=num_classes) if name == 'DenseNet121CIFAR' else\
        GoogLeNet(
            num_classes=num_classes) if name == 'GoogLeNet' else\
        InceptionV3(
            num_classes=num_classes) if name == 'InceptionV3' else\
        MNASNet_0_5(
            num_classes=num_classes) if name == 'MNASNet_0_5' else\
        MNASNet_0_75(
            num_classes=num_classes) if name == 'MNASNet_0_75' else\
        MNASNet_1(
            num_classes=num_classes) if name == 'MNASNet_1' else\
        MNASNet_1_3(
            num_classes=num_classes) if name == 'MNASNet_1_3' else\
        MobileNetV2(
            num_classes=num_classes) if name == 'MobileNetV2' else\
        ResNet18(
            num_classes=num_classes) if name == 'ResNet18' else\
        ResNet34(
            num_classes=num_classes) if name == 'ResNet34' else\
        ResNet34CIFAR(
            num_classes=num_classes) if name == 'ResNet34CIFAR' else\
        ResNet50CIFAR(
            num_classes=num_classes) if name == 'ResNet50CIFAR' else\
        ResNet101CIFAR(
            num_classes=num_classes) if name == 'ResNet101CIFAR' else\
        ResNet18CIFAR(
            num_classes=num_classes) if name == 'ResNet18CIFAR' else\
        ResNet50(
            num_classes=num_classes) if name == 'ResNet50' else\
        ResNet101(
            num_classes=num_classes) if name == 'ResNet101' else\
        ResNet152(
            num_classes=num_classes) if name == 'ResNet152' else\
        ResNeXt50(
            num_classes=num_classes) if name == 'ResNext50' else\
        ResNeXtCIFAR(
            num_classes=num_classes) if name == 'ResNeXtCIFAR' else\
        ResNeXt101(
            num_classes=num_classes) if name == 'ResNext101' else\
        WideResNet50(
            num_classes=num_classes) if name == 'WideResNet50' else\
        WideResNet101(
            num_classes=num_classes) if name == 'WideResNet101' else\
        ShuffleNetV2_0_5(
            num_classes=num_classes) if name == 'ShuffleNetV2_0_5' else\
        ShuffleNetV2_1(
            num_classes=num_classes) if name == 'ShuffleNetV2_1' else\
        ShuffleNetV2_1_5(
            num_classes=num_classes) if name == 'ShuffleNetV2_1_5' else\
        ShuffleNetV2_2(
            num_classes=num_classes) if name == 'ShuffleNetV2_2' else\
        SqueezeNet_1(
            num_classes=num_classes) if name == 'SqueezeNet_1' else\
        SqueezeNet_1_1(
            num_classes=num_classes) if name == 'SqueezeNet_1_1' else\
        VGG11(
            num_classes=num_classes) if name == 'VGG11' else\
        VGG11_BN(
            num_classes=num_classes) if name == 'VGG11_BN' else\
        VGG13(
            num_classes=num_classes) if name == 'VGG13' else\
        VGG13_BN(
            num_classes=num_classes) if name == 'VGG13_BN' else\
        VGG16(
            num_classes=num_classes) if name == 'VGG16' else\
        VGG16_BN(
            num_classes=num_classes) if name == 'VGG16_BN' else\
        VGG19(
            num_classes=num_classes) if name == 'VGG19' else\
        VGG19_BN(
            num_classes=num_classes) if name == 'VGG19_BN' else \
        VGGCIFAR('VGG16',
                 num_classes=num_classes) if name == 'VGG16CIFAR' else \
        EfficientNetB4(
            num_classes=num_classes) if name == 'EfficientNetB4' else \
        EfficientNetB0CIFAR(
            num_classes=num_classes) if name == 'EfficientNetB0CIFAR' else\
        None
Example #23
def main():
    """Main function."""
    # Load configuration file
    config = open("/workspaces/DD2424-project/configs/alexnet_config.json")

    # Create Dataset of tiny-imagenet from config
    dataset = Dataset(config)

    # Get training data and validation data
    ds_train = dataset.get_data("train")
    ds_test = dataset.get_data("val")

    config_alex = open("/workspaces/DD2424-project/configs/alexnet_config.json")

    # Train pure alexnet according to configuration file
    alex = AlexNet(config_alex)
    alex.set_train_data(ds_train)
    alex.set_test_data(ds_test)

    alex.generate_model()
    alex.summary()
    alex.start_train()
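open() here returns a raw file handle that is never explicitly parsed or closed; if Dataset and AlexNet actually expect a parsed dict (an assumption, they may parse the handle themselves), the loading step would look like:

import json

with open("/workspaces/DD2424-project/configs/alexnet_config.json") as f:
    config = json.load(f)  # a parsed dict instead of an open file handle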
Example #24
pruned_weight_root = './AlexNet/pruned_weight_100k'
pretrain_model_path = './AlexNet/alexnet-owt-4df8aa71.pth'
n_validate_batch = 100  # Number of batches used for validation
validate_batch_size = 50  # Batch size of validation
prune_ratio = {
    'features.0': 80,
    'features.3': 35,
    'features.6': 35,
    'features.8': 35,
    'features.10': 35,
    'classifier.1': 10,
    'classifier.4': 10,
    'classifier.6': 35
}
# -------------------------------------------- User Config ------------------------------------
net = AlexNet()
net.load_state_dict(torch.load(pretrain_model_path))
param = net.state_dict()
total_nnz = 0
total_nelements = 0

for layer_name, CR in prune_ratio.items():

    if CR != 100:
        pruned_weight = np.load('%s/CR_%d/%s.weight.npy' %
                                (pruned_weight_root, CR, layer_name))
        pruned_bias = np.load('%s/CR_%d/%s.bias.npy' %
                              (pruned_weight_root, CR, layer_name))

        # Calculate sparsity
        total_nnz += np.count_nonzero(pruned_weight)
Example #25
def main():

    global best_acc1
    best_acc1 = 0

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    logger = getLogger(args.save_folder)
    if args.dataset.startswith('imagenet') or args.dataset.startswith(
            'places'):
        image_size = 224
        crop_padding = 32
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        normalize = transforms.Normalize(mean=mean, std=std)
        if args.aug == 'NULL':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        elif args.aug == 'CJ':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size,
                                             scale=(args.crop, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            raise NotImplementedError('augmentation not supported: {}'.format(
                args.aug))

        val_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
        if args.dataset.startswith('imagenet'):
            train_dataset = datasets.ImageFolder(train_folder, train_transform)
            val_dataset = datasets.ImageFolder(
                val_folder,
                val_transform,
            )

        if args.dataset.startswith('places'):
            train_dataset = ImageList(
                '/data/trainvalsplit_places205/train_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=train_transform,
                symbol_split=' ')
            val_dataset = ImageList(
                '/data/trainvalsplit_places205/val_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=val_transform,
                symbol_split=' ')

        print(len(train_dataset))
        train_sampler = None

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.n_workers,
            pin_memory=False,
            sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.n_workers,
                                                 pin_memory=False)
    elif args.dataset.startswith('cifar'):
        train_loader, val_loader = cifar.get_linear_dataloader(args)
    elif args.dataset.startswith('svhn'):
        train_loader, val_loader = svhn.get_linear_dataloader(args)

    # create model and optimizer
    if args.model == 'alexnet':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer, args.n_label, 'avg')
    elif args.model == 'alexnet_cifar':
        if args.layer == 6:
            args.layer = 5
        model = AlexNet_cifar(128)
        model = nn.DataParallel(model)
        classifier = LinearClassifierAlexNet(args.layer,
                                             args.n_label,
                                             'avg',
                                             cifar=True)
    elif args.model == 'resnet50':
        model = resnet50(non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet18':
        model = resnet18()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet18_cifar':
        model = resnet18_cifar()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer,
                                            args.n_label,
                                            'avg',
                                            1,
                                            bottleneck=False)
    elif args.model == 'resnet50_cifar':
        model = resnet50_cifar()
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 4)
    elif args.model == 'shufflenet':
        model = shufflenet_v2_x1_0(num_classes=128, non_linear_head=False)
        model = nn.DataParallel(model)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg',
                                            0.5)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    print('==> loading pre-trained model')
    ckpt = torch.load(args.model_path)
    if not args.moco:
        model.load_state_dict(ckpt['state_dict'])
    else:
        try:
            state_dict = ckpt['state_dict']
            for k in list(state_dict.keys()):
                # retain only encoder_q weights up to before the embedding head,
                # e.g. 'module.encoder_q.conv1.weight' -> 'module.conv1.weight'
                if k.startswith('module.encoder_q'
                                ) and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict['module.' +
                               k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            model.load_state_dict(state_dict)
        except (KeyError, RuntimeError) as e:
            # don't swallow loading errors silently; report and continue
            print('==> warning: could not remap MoCo weights: {}'.format(e))
    print("==> loaded checkpoint '{}' (epoch {})".format(
        args.model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    classifier = classifier.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    if not args.adam:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8)

    model.eval()
    cudnn.benchmark = True
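    # Note: this is a linear probe. The backbone was set to eval mode above (so
    # BatchNorm statistics stay frozen) and only classifier.parameters() were
    # handed to the optimizer, so optimizer steps never touch the backbone.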

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc1 = checkpoint['best_acc1']
            # best_acc1 is a tensor in older checkpoints, a plain float otherwise
            if torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.item()
            print('resumed best_acc1: {:.4f}'.format(best_acc1))
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                #args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    tblogger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    best_acc = 0.0
    # best_acc1 is restored above when resuming; otherwise it must start at 0,
    # or the comparison in the saving step below would raise a NameError
    if not (args.resume and os.path.isfile(args.resume)):
        best_acc1 = 0.0
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model,
                                                  classifier, criterion,
                                                  optimizer, args)
        time2 = time.time()
        logging.info('train epoch {}, total time {:.2f}'.format(
            epoch, time2 - time1))

        logging.info(
            'Epoch: {}, lr:{} , train_loss: {:.4f}, train_acc: {:.4f}/{:.4f}'.
            format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                   train_acc, train_acc5))

        tblogger.log_value('train_acc', train_acc, epoch)
        tblogger.log_value('train_acc5', train_acc5, epoch)
        tblogger.log_value('train_loss', train_loss, epoch)
        tblogger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                           epoch)

        test_acc, test_acc5, test_loss = validate(val_loader, model,
                                                  classifier, criterion, args)

        if test_acc >= best_acc:
            best_acc = test_acc

        logging.info(
            colorful(
                'Epoch: {}, val_loss: {:.4f}, val_acc: {:.4f}/{:.4f}, best_acc: {:.4f}'
                .format(epoch, test_loss, test_acc, test_acc5, best_acc)))
        tblogger.log_value('test_acc', test_acc, epoch)
        tblogger.log_value('test_acc5', test_acc5, epoch)
        tblogger.log_value('test_loss', test_loss, epoch)

        # save the best model
        if test_acc > best_acc1:
            best_acc1 = test_acc
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}_layer{}.pth'.format(args.model, args.layer)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': test_acc,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

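The MoCo branch above rewrites checkpoint keys in place. A minimal, self-contained sketch of that renaming (hypothetical keys, same string logic) makes the effect easier to see:

# Demo of the encoder_q key remapping performed when loading a MoCo checkpoint:
# keys under the projection head ('module.encoder_q.fc') are dropped, the key
# encoder is dropped, and every other encoder_q key loses the 'encoder_q.' segment.
state_dict = {
    'module.encoder_q.conv1.weight': 'w0',  # kept, renamed
    'module.encoder_q.fc.0.weight': 'w1',   # projection head -> dropped
    'module.encoder_k.conv1.weight': 'w2',  # key encoder -> dropped
}
for k in list(state_dict.keys()):
    if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
        state_dict['module.' + k[len('module.encoder_q.'):]] = state_dict[k]
    del state_dict[k]
print(state_dict)  # {'module.conv1.weight': 'w0'}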
Example #26
0
File: main.py Project: jozhang97/WaveApp
def main():
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Data loading code
  # Any other preprocessing? http://pytorch.org/audio/transforms.html
  sample_length = 10000
  scale = transforms.Scale()
  padtrim = transforms.PadTrim(sample_length)
  downmix = transforms.DownmixMono()
  transforms_audio = transforms.Compose([
    scale, padtrim, downmix
  ])

  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)
  train_dir = os.path.join(args.data_path, 'train')
  val_dir = os.path.join(args.data_path, 'val')

  #Choose dataset to use
  if args.dataset == 'arctic':
    # TODO No ImageFolder equivalent for audio; a Dataset must be created manually
    # (a minimal sketch of such a Dataset follows after this function)
    train_dataset = Arctic(train_dir, transform=transforms_audio, download=True)
    val_dataset = Arctic(val_dir, transform=transforms_audio, download=True)
    num_classes = 4
  elif args.dataset == 'vctk':
    train_dataset = dset.VCTK(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.VCTK(val_dir, transform=transforms_audio, download=True)
    num_classes = 10
  elif args.dataset == 'yesno':
    train_dataset = dset.YESNO(train_dir, transform=transforms_audio, download=True)
    val_dataset = dset.YESNO(val_dir, transform=transforms_audio, download=True)
    num_classes = 2
  else:
    assert False, 'Dataset is incorrect'

  train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    # pin_memory=True would stage batches in page-locked memory, speeding host-to-GPU copies
    # sampler would override the default sampling order; unnecessary here since shuffle=True
  )
  val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)


  #Feed in respective model file to pass into model (alexnet.py)
  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  # net = models.__dict__[args.arch](num_classes)
  net = AlexNet(num_classes)
  #
  print_log("=> network :\n {}".format(net), log)

  # net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  # Define stochastic gradient descent as optimizer (run backprop on random small batch)
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=True)

  #Sets use for GPU if available
  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  # The checkpoint must be loaded with the same Python version it was saved with
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      if args.ngpu == 0:
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      else:
        checkpoint = torch.load(args.resume)

      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(args.resume), log)
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(val_loader, net, criterion, 0, log, val_dataset)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()

  # Training occurs here
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    print("One epoch")
    # train for one epoch
    # Call to train (note that our previous net is passed into the model argument)
    train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log, train_dataset)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(val_loader, net, criterion, epoch, log, val_dataset)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
    }, is_best, args.save_path, 'checkpoint.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

  log.close()
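The TODO in the dataset branch above notes that there is no ImageFolder equivalent for audio. A minimal sketch of such a Dataset, assuming a root/<class>/<file>.wav layout and the standard torchaudio.load API (this is a sketch, not the project's Arctic class):

import os
import torchaudio
from torch.utils.data import Dataset

class AudioFolder(Dataset):
  """ImageFolder-style dataset for audio files, one subdirectory per class."""

  def __init__(self, root, transform=None):
    self.transform = transform
    self.samples = []
    # assign a label index to each class subdirectory, then collect its files
    for label, cls in enumerate(sorted(os.listdir(root))):
      cls_dir = os.path.join(root, cls)
      for fname in sorted(os.listdir(cls_dir)):
        self.samples.append((os.path.join(cls_dir, fname), label))

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx):
    path, label = self.samples[idx]
    waveform, sample_rate = torchaudio.load(path)  # (Tensor, int)
    if self.transform is not None:
      waveform = self.transform(waveform)
    return waveform, label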
def jobSetup():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    done = False  # controls leaving job setup ('exit' would shadow the builtin)
    joblist = []
    while (not done):
        # These booleans control the state of the menu
        SessionTypeBool = True
        ModelTypeBool = True
        EpochBool = True
        TrainBatchBool = True
        OptimBool = True
        TestBatchBool = True
        jobBool = True

        #--------------------------------------Model Selection--------------------------------------#
        while (ModelTypeBool):
            modeltype = input(
                " a.Alexnet \n b.VGG16  \n c.ResNext  \n d.VGGv2\n   >")
            if (modeltype != 'a' and modeltype != 'b' and modeltype != 'c'
                    and modeltype != 'd'):
                print("Please input a valid model input")
                ModelTypeBool = True

            if (modeltype == 'a'):
                model = AlexNet()
                modeldict = 'Alexnet-model.pt'
                modelname = "Alexnet"
                valtrain = 32
                valtest = 136
                optimizer = optim.Adam(model.parameters(), lr=0.001)
                ModelTypeBool = False

            elif (modeltype == 'b'):
                model = VGG16()
                modeldict = 'VGG16-model.pt'
                modelname = "VGG16"
                valtrain = 32
                valtest = 136
                optimizer = optim.SGD(model.parameters(), lr=0.001)
                ModelTypeBool = False

            elif (modeltype == 'c'):
                model = resnext50_32x4d()
                modeldict = 'ResNext50-model.pt'
                modelname = "ResNext50"
                valtrain = 32
                valtest = 136
                optimizer = optim.Adam(model.parameters(), lr=0.001)
                ModelTypeBool = False

            elif (modeltype == 'd'):
                model = VGG_v2()
                modeldict = 'VGGv2-model.pt'
                modelname = "VGGv2"
                valtrain = 32
                valtest = 136
                optimizer = optim.Adam(model.parameters(), lr=0.001)
                ModelTypeBool = False

        print(modelname + ": chosen")

        #------------------------------------Session Selection--------------------------------------#
        while (SessionTypeBool):
            sessiontype = input(
                " a.Start Training a new model \n b.Test the model \n   >")
            if (sessiontype != 'a' and sessiontype != 'b'
                    and sessiontype != 'c'):
                print("Please input a valid session input")
                SessionTypeBool = True
            if (sessiontype == 'a'):
                SessionTypeBool = False
                print("From Stratch: chosen")
            elif (sessiontype == 'b'):
                SessionTypeBool = False
                TrainBatchBool = False
                OptimBool = False
                EpochBool = False
                valtrain = 1
                epochval = 1
                print("Testing: chosen")
        # UNCOMMENT FOR THE CONTINUE-TRAINING OPTION. Use at your own risk!
            """
         elif (sessiontype == 'c'):
            SessionTypeBool = False
            print ("Testing: chosen")
         """
        #------------------------------------Epoch Selection--------------------------------------#
        while (EpochBool):
            epoch = input(" Number of Epochs:   ")
            try:
                epochval = int(epoch)
                print(f'\nEpochs chosen: {epochval}')
                EpochBool = False
            except ValueError:
                print("Please input a valid Epochs input")
                EpochBool = True

        # This section is DEVELOPER USE ONLY. We do not want the user to change the training or test batch numbers
        # as this can lead to CUDA out of memory errors. Uncomment and use at your own risk!
        """
      #------------------------------------Optimiser Selection---------------------------------#
      while (OptimBool):
         optimiseinput = input(" Optimizer (Debug): \n a.Adam \n b.SGD  \n   >") 
         if (optimiseinput != 'a' and optimiseinput != 'b'):
            print ("Please input a valid Optimizer input")
            OptimBool = True
         if (optimiseinput == 'a'):  
            optimizer = optim.Adam(model.parameters(), lr=0.001)
            print ("Adam chosen")
            OptimBool = False
         elif (optimiseinput == 'b'):
            optimizer = optim.SGD(model.parameters(), lr=0.001)
            print ("SGD chosen")
            OptimBool = False
      #------------------------------------Batch Selection---------------------------------#
      while (TrainBatchBool):
         trainbatch = input(" Number of train batches (Debug):   ")
         try:
            valtrain = int(trainbatch)
            print(f'\ntraining batches chosen: {valtrain}')
            TrainBatchBool = False
         except ValueError:
            print("Please input a valid batch count")
            TrainBatchBool = True

      while (TestBatchBool):
         testbatch = input(" Number of test batches (Debug):   ")
         try:
            valtest = int(testbatch)
            print(f'\ntest batches chosen: {valtest}')
            TestBatchBool = False
         except ValueError:
            print("Please input a valid batch count")
            TestBatchBool = True
      """
        #------------------------------------Job Menu---------------------------------------#
        job = jobclass(sessiontype, model, modeldict, optimizer, epochval,
                       device, valtrain, valtest, modelname)
        joblist.append(job)

        while (jobBool):
            finish = input(
                " Would you like to run another Model after? y/n:   ")
            if (finish != 'y' and finish != 'n'):
                print("Please input a valid job input")
                jobBool = True
            if (finish == 'y'):
                jobBool = False
                print("Add another job")

            if (finish == 'n'):
                jobBool = False
                done = True
                print("Jobs Executing")
    return joblist
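jobSetup() only builds the queue; the code that consumes it is not shown here. A hedged sketch, assuming jobclass simply stores its constructor arguments as attributes and that run_training/run_testing entry points exist (both names are hypothetical):

# Hypothetical runner for the joblist built above; the attribute names mirror
# the jobclass(...) constructor call and are assumptions, not the project's API.
def runJobs(joblist):
    for job in joblist:
        job.model.to(job.device)
        if job.sessiontype == 'a':   # train a new model, then save its weights
            run_training(job.model, job.optimizer, job.epochval,
                         job.valtrain, job.valtest)
            torch.save(job.model.state_dict(), job.modeldict)
        else:                        # 'b': load saved weights and test
            job.model.load_state_dict(torch.load(job.modeldict))
            run_testing(job.model, job.valtest)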
def train(data_train, data_val, num_classes, num_epoch, milestones):
    model = AlexNet(num_classes, pretrain=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    since = time.time()
    best_acc = 0
    best = 0
    for epoch in range(num_epoch):
        print('Epoch {}/{}'.format(epoch + 1, num_epoch))
        print('-' * 10)


        # Iterate over data.
        running_loss = 0.0
        running_corrects = 0
        model.train()
        with torch.set_grad_enabled(True):
            for i, (inputs, labels) in enumerate(data_train):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)
                print("\rIteration: {}/{}, Loss: {}.".format(i + 1, len(data_train), loss.item()), end="")

                sys.stdout.flush()

        avg_loss = running_loss / len(data_train)
        t_acc = running_corrects.double() / len(data_train)

        running_loss = 0.0
        running_corrects = 0
        model.eval()
        with torch.set_grad_enabled(False):
            for i, (inputs, labels) in enumerate(data_val):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data) * 1. / inputs.size(0)

        val_loss = running_loss / len(data_val)
        val_acc = running_corrects.double() / len(data_val)

        print()
        print('Train Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, t_acc))
        print('Val Loss: {:.4f} Acc: {:.4f}'.format(val_loss, val_acc))
        print('lr rate: {:.6f}'.format(optimizer.param_groups[0]['lr']))
        print()

        if val_acc > best_acc:
            best_acc = val_acc
            best = epoch + 1

        lr_scheduler.step()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best Validation Accuracy: {}, Epoch: {}'.format(best_acc, best))

    return model
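The MultiStepLR schedule above multiplies the learning rate by gamma=0.1 each time a milestone epoch is passed. A tiny standalone check of that behavior (with hypothetical milestones):

import torch
from torch.optim.lr_scheduler import MultiStepLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.0001)
scheduler = MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1)

for epoch in range(8):
    optimizer.step()   # stepping the optimizer first avoids the scheduler-order warning
    scheduler.step()
    print(epoch + 1, optimizer.param_groups[0]['lr'])
# prints 1e-4 for epochs 1-2, 1e-5 for epochs 3-5, 1e-6 from epoch 6 on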
Example #29
0
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--classify',
                        help='Predict the class of an input image',
                        type=str)
    parser.add_argument('--test',
                        help='Evaluate accuracy on the test set',
                        action='store_true')
    parser.add_argument('--validation',
                        help='Evaluate accuracy on the validation set',
                        action='store_true')
    args = parser.parse_args()

    cfg = Configuration()
    net = AlexNet(cfg, training=False)

    testset = ImageNetDataset(cfg, 'test')

    if tfe.num_gpus() > 2:
        # The threshold of 2 effectively disables this branch; set it to 0 to
        # run on the GPU, but that currently fails because tf.in_top_k has no
        # CUDA implementation
        with tf.device('/gpu:0'):
            tester = Tester(cfg, net, testset)

            if args.classify:
                tester.classify_image(args.classify)
            elif args.validation:
                tester.test('validation')
            else:
                tester.test('test')
Example #30
0
def main():
    parser = argparse.ArgumentParser(description='Process some integers.')
    parser.add_argument('--weight_path',
                        type=str,
                        default="results/nakamoto_dataaug_alexnet.pkl",
                        help='weight data path')
    parser.add_argument('--reshape_size',
                        type=int,
                        default=227,
                        help='reshape size')
    parser.add_argument('--model',
                        type=str,
                        default="alexnet",
                        help='alexnet or vgg11 or deepsimplenet')
    args = parser.parse_args()
    print("judge start")

    # Nakamoto (中本) ramen menu categories
    categories = [
        "gomoku_miso", "gomoku_moko", "hiyasi_gomoku", "hiyasi_miso",
        "hokkyoku", "hokkyoku_yasai", "miso", "miso_ran", "moko", "moko_ran",
        "sio"
    ]

    categories_ja = [
        "五目味噌タンメン", "五目蒙古タンメン", "冷やし五目", "冷やし味噌野菜", "北極", "北極野菜", "味噌タンメン",
        "味噌卵麺", "蒙古タンメン", "蒙古卵麺", "塩タンメン"
    ]

    #pokemon
    # categories = ["pikachu","Eevee","Numera"]

    # root_dir = "dataset_nakamoto_inf"
    judge_data = ["judge_data"]

    prob = []
    # array for image data
    X = []
    # array for label data
    Y = []

    if args.model == "alexnet":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))
    elif args.model == "vgg11":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))
    elif args.model == "deepsimplenet":
        network = AlexNet(input_dim=(3, args.reshape_size, args.reshape_size),
                          output_size=len(categories))

    # network.load_params("nakamotoinf_alexnet.pkl")
    network.load_params(args.weight_path)

    # array holding all data
    allfiles = []

    # pair each judge_data entry with its index and gather every file into allfiles
    for idx, item in enumerate(judge_data):
        image_dir = item
        files = glob.glob(image_dir + "/*")
        for f in files:
            allfiles.append((idx, f))

    X_train, y_train = make_sample(allfiles, args.reshape_size)
    X_train = X_train.astype(np.float32)
    X_train /= 255.0
    x_train = X_train.transpose(0, 3, 1, 2)

    print(Fore.RED)

    y = network.predict(x_train)

    for i in range(len(y)):
        print("")
        print("")
        print("")
        answer = y[i].argsort()[::-1]
        print("予測精度:", int(calc_prob(y[i]) * 100), "%")
        print(
            nakamoto_review(categories[answer[0]], categories_ja[answer[0]],
                            categories_ja[answer[1]],
                            categories_ja[answer[2]]))

    print(Style.RESET_ALL)
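calc_prob and nakamoto_review come from the project's own utilities and are not shown in this snippet. A plausible stand-in for calc_prob, assuming it softmaxes the network's raw scores and reports the probability of the top class (an assumption, not the project's implementation):

import numpy as np

def calc_prob(scores):
    # Hypothetical: softmax the raw output scores (shifted for numerical
    # stability), then return the probability assigned to the most likely class.
    shifted = scores - np.max(scores)
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    return float(np.max(probs))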