def main():
    parser = train_args.get_args()
    cli_args = parser.parse_args()

    if cli_args.arch != 'alexnet':
        print('Currently, we support only AlexNet.')
        exit(1)

    use_cuda = False
    epochs = cli_args.epochs
    checkpoint_name = 'checkpoint.pt'

    save_dir = cli_args.save_dir if cli_args.save_dir else None
    save_name = cli_args.save_name if cli_args.save_name else None

    if save_dir and save_name:
        checkpoint_name = f'{save_dir}/{save_name}.pt'

    # use CUDA only if it was requested and is actually available
    if cli_args.use_gpu and torch.cuda.is_available():
        use_cuda = True
    else:
        print("GPU not requested or not available. Using CPU.")

    hidden_units = cli_args.hidden_units

    # check for data directory
    if not os.path.isdir(cli_args.data_directory):
        print(f'Data directory {cli_args.data_directory} was not found.')
        exit(1)

    # check for save directory
    if save_dir and not os.path.isdir(save_dir):
        print(f'Directory {save_dir} does not exist. Creating...')
        os.makedirs(save_dir)

    # data directories (the validation path is hard-coded here)
    train_dir = cli_args.data_directory
    valid_dir = 'flowers/valid'

    train_transform = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    valid_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    data_transforms = {'train': train_transform, 'valid': valid_transform}

    train_dataset = ImageFolder(train_dir, transform=train_transform)
    valid_dataset = ImageFolder(valid_dir, transform=valid_transform)

    image_datasets = {'train': train_dataset, 'valid': valid_dataset}

    batch_size = 20
    num_workers = 0

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size,
                                               num_workers=num_workers,
                                               shuffle=False)

    dataloaders = {'train': train_loader, 'valid': valid_loader}

    with open('cat_to_name.json', 'r') as f:
        cat_to_name = json.load(f)

    model_transfer, criterion_transfer, optimizer_transfer = get_model(
        use_cuda=use_cuda, hidden_units=hidden_units)
    model_transfer = train(dataloaders, model_transfer, optimizer_transfer,
                           criterion_transfer, checkpoint_name, epochs,
                           use_cuda, train_dataset)
Example No. 2
if os.path.exists("net.pkl"):
    pkl = torch.load("net.pkl")
    net = pkl.get("model")
    sepoch = pkl.get("epoch")
else:
    net = AlexNet().cuda()
    sepoch = 1

criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.SGD(params=net.parameters(), lr=1e-2, momentum=9e-1)

data_loader = DataLoader(dataset=ImageFolder(
    "data/train",
    transform=transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])),
                         batch_size=50,
                         shuffle=True)


def adjust_learning_rate(epoch):
    lr = 1e-2 * 1e-1**(epoch // 20)
    for group in optimizer.param_groups:
        group['lr'] = lr


def train():
    net.train()
Example No. 3
def get_train_valid_loader(data_dir,
                           batch_size,
                           valid_size=0.2,
                           shuffle=True,
                           show_sample=False,
                           num_workers=0,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over an ImageFolder-style dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    - n_train, n_valid: number of training / validation samples.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),  # convert to a Tensor and scale pixel values to [0, 1]
        normalize
    ])

    # load the dataset
    train_dataset = ImageFolder(data_dir, transform=transform)
    valid_dataset = ImageFolder(data_dir, transform=transform)

    num_train = len(train_dataset)
    #num_valid = len(valid_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        #np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        data_iter = iter(sample_loader)
        images, labels = next(data_iter)
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader, len(train_idx), len(valid_idx))
Example No. 4
def create_app(node_id, debug=False, database_url=None, data_dir: str = None):
    """ Create / Configure flask socket application instance.
        
        Args:
            node_id (str) : ID of Grid Node.
            debug (bool) : debug flag.
            test_config (bool) : Mock database environment.
        Returns:
            app : Flask application instance.
    """
    app = Flask(__name__)
    app.debug = debug

    app.config["SECRET_KEY"] = "justasecretkeythatishouldputhere"

    # Enable persistent mode
    # Overwrite syft.object_storage methods to work in a persistent way
    # Persist models / tensors
    if database_url:
        app.config["REDISCLOUD_URL"] = database_url
        from .main.persistence import database, object_storage

        db_instance = database.set_db_instance(database_url)
        object_storage.set_persistent_mode(db_instance)

    from .main import html, ws, hook, local_worker, auth

    # Global socket handler
    sockets = Sockets(app)

    # set_node_id(id)
    local_worker.id = node_id
    hook.local_worker._known_workers[node_id] = local_worker
    local_worker.add_worker(hook.local_worker)

    # add data
    if data_dir:
        print("register data")
        if "mnist" in data_dir.lower():
            dataset = MNIST(
                root="./data",
                train=True,
                download=True,
                transform=transforms.Compose(
                    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
                ),
            )
            if node_id in KEEP_LABELS_DICT:
                indices = np.isin(dataset.targets, KEEP_LABELS_DICT[node_id]).astype(
                    "uint8"
                )
                selected_data = (
                    torch.native_masked_select(  # pylint:disable=no-member
                        dataset.data.transpose(0, 2),
                        torch.tensor(indices),  # pylint:disable=not-callable
                    )
                    .view(28, 28, -1)
                    .transpose(2, 0)
                )
                selected_targets = torch.native_masked_select(  # pylint:disable=no-member
                    dataset.targets,
                    torch.tensor(indices),  # pylint:disable=not-callable
                )
            """ dataset = sy.BaseDataset(
                    data=selected_data,
                    targets=selected_targets,
                    transform=dataset.transform,
                )
            dataset_name = "mnist"
            """
        else:

            train_tf = [
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomAffine(
                    degrees=30,
                    translate=(0, 0),
                    scale=(0.85, 1.15),
                    shear=10,
                    #    fillcolor=0.0,
                ),
                transforms.Resize(224),
                transforms.RandomCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.57282609,), (0.17427578,)),
                # transforms.RandomApply([AddGaussianNoise(mean=0.0, std=0.05)], p=0.5),
            ]
            """train_tf.append(
                transforms.Lambda(
                    lambda x: torch.repeat_interleave(  # pylint: disable=no-member
                        x, 3, dim=0
                    )
                )
            )"""
            target_dict_pneumonia = {0: 1, 1: 0, 2: 2}
            dataset = ImageFolder(
                data_dir,
                transform=transforms.Compose(train_tf),
                target_transform=lambda x: target_dict_pneumonia[x],
            )
            data, targets = [], []
            for d, t in tqdm(dataset, total=len(dataset)):
                data.append(d)
                targets.append(t)
            selected_data = torch.stack(data)  # pylint:disable=no-member
            selected_targets = torch.from_numpy(np.array(targets))  # pylint:disable=no-member
            #dataset = sy.BaseDataset(data=data, targets=targets)
            """dataset = PPPP(
                "data/Labels.csv",
                train=True,
                transform=transforms.Compose(train_tf),
                seed=1
            )"""
            dataset_name = "pneumonia"

        local_worker.register_obj(selected_data, "data")
        local_worker.register_obj(selected_targets, "targets")

        print("registered data")

    # Register app blueprints
    app.register_blueprint(html, url_prefix=r"/")
    sockets.register_blueprint(ws, url_prefix=r"/")

    # Set Authentication configs
    app = auth.set_auth_configs(app)

    return app
Example No. 5
classes1 = os.listdir(data_dir + "/test")
print(classes1)

airplane_files = os.listdir(data_dir + "/train/airplane")
print("Number of training example for airplane: ", len(airplane_files))
print(airplane_files[:5])

ship_test_file = os.listdir(data_dir + "/test/ship")
print("Number of test example for test: ", len(ship_test_file))
print(ship_test_file[:5])

from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

dataset = ImageFolder(data_dir + '/train', transform=ToTensor())

# print a sample element from the training dataset; each element is a tuple containing the image tensor and its label
img, label = dataset[3]
print(img.shape,label)
img

print(dataset.classes)

# we can view the image using matplotlib, but we need to change the tensor dimensions to (32, 32, 3)
import matplotlib.pyplot as plt
def show_example(img, label):
  print('label: ',dataset.classes[label], "("+str(label)+")")
  plt.imshow(img.permute(1,2,0))

show_example(*dataset[0])
Example No. 6
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(512, 5)
    model = model.to(device)
    # summary
    from torchsummary import summary
    summary(model, (3, 128, 64))
    model.train(mode=True)
    # print(model)
    transform = transforms.Compose([
        transforms.Resize((128, 64)),
        transforms.ToTensor(),
    ])
    dataset = ImageFolder('train', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=128,
                                             shuffle=True)

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=4,
                                                gamma=0.85)

    # iterate over epochs
    acc_list = []
    for epoch in trange(10):

        correct = 0
Example No. 7
import torchvision.transforms as transforms
# https://towardsdatascience.com/a-beginners-tutorial-on-building-an-ai-image-classifier-using-pytorch-6f85cb69cba7
transformations = transforms.Compose([
    transforms.Resize(800),
    transforms.CenterCrop(800),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# ----------------------------------------------------------------------------------------------------------------
# Load files into pytorch
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

print()
print('This is train_data class label mapping:')
train_data = ImageFolder(os.path.join(base_path, 'train_data'),
                         transform=transformations)
print(train_data.class_to_idx)
print('The number of paintings in train dataset: ')
print(len(train_data))

print()
print('This is test_data class label mapping:')
test_data = ImageFolder(os.path.join(base_path, 'test_data'),
                        transform=transformations)
print(test_data.class_to_idx)
print('The number of paintings in test dataset: ')
print(len(test_data))

# https://discuss.pytorch.org/t/questions-about-imagefolder/774/6
# https://github.com/amir-jafari/Deep-Learning/blob/master/Pytorch_/6-Conv_Mnist/Conv_Mnist_gpu.py
from torch.utils.data import DataLoader
Example No. 8
 def ds_clf(self):
     t = base_transform()
     return ImageFolder(root="/imagenet224/train", transform=t)
Example No. 9
 def ds_test(self):
     t = base_transform()
     return ImageFolder(root="/imagenet224/val", transform=t)
def load_data(opt):
    """ Load Data

    Args:
        opt ([type]): Argument Parser

    Raises:
        IOError: Cannot Load Dataset

    Returns:
        [type]: dataloader
    """

    ##
    # LOAD DATA SET
    if opt.dataroot == '':
        opt.dataroot = './data/{}'.format(opt.dataset)

    ## CIFAR
    if opt.dataset in ['cifar10']:
        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_ds = CIFAR10(root='./data',
                           train=True,
                           download=True,
                           transform=transform)
        valid_ds = CIFAR10(root='./data',
                           train=False,
                           download=True,
                           transform=transform)
        train_ds, valid_ds = get_cifar_anomaly_dataset(train_ds, valid_ds,
                                                       int(opt.abnormal_class))

    ## MNIST
    elif opt.dataset in ['mnist']:
        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        train_ds = MNIST(root='./data',
                         train=True,
                         download=True,
                         transform=transform)
        valid_ds = MNIST(root='./data',
                         train=False,
                         download=True,
                         transform=transform)
        train_ds, valid_ds = get_mnist_anomaly_dataset(train_ds, valid_ds,
                                                       int(opt.abnormal_class))

    # FOLDER
    elif opt.dataset in ['OCT']:
        # TODO: fix the OCT dataset into the dataloader and return
        class OverLapCrop():
            def __init__(self, img_size):
                self.img_size = img_size

            def __call__(self, x):
                ret = []
                for i in range(256 // opt.isize):
                    for j in range(256 // opt.isize):
                        ret.append(x[:, i * opt.isize:(i + 1) * opt.isize,
                                     j * opt.isize:(j + 1) * opt.isize])
                return ret

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': True}
        shuffle = {'train': True, 'test': False}
        train_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ColorJitter(brightness=0.1,
                                   contrast=0.1,
                                   saturation=0.1,
                                   hue=0.1),
            transforms.Resize([256, 256]),
            transforms.RandomCrop(opt.isize),  #
            transforms.ToTensor(),
        ])

        test_transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize([256, 256]),
            transforms.ToTensor(),
            OverLapCrop(opt.isize),
            transforms.Lambda(lambda crops: torch.stack(crops)),
        ])

        dataset = {
            'train':
            ImageFolder(os.path.join(opt.dataroot, 'train'), train_transform),
            'test':
            ImageFolder(os.path.join(opt.dataroot, 'test'), test_transform),
        }

        train_dl = DataLoader(dataset=dataset['train'],
                              batch_size=opt.batchsize,
                              shuffle=shuffle['train'],
                              num_workers=int(opt.n_cpu),
                              drop_last=drop_last_batch['train'],
                              worker_init_fn=lambda x: np.random.seed(42))
        valid_dl = DataLoader(dataset=dataset['test'],
                              batch_size=opt.batchsize,
                              shuffle=shuffle['test'],
                              num_workers=int(opt.n_cpu),
                              drop_last=drop_last_batch['test'],
                              worker_init_fn=lambda x: np.random.seed(42))
        return Data(train_dl, valid_dl)

    elif opt.dataset in ['KDD99']:
        train_ds = KDD_dataset(opt, mode='train')
        valid_ds = KDD_dataset(opt, mode='test')

    else:
        raise NotImplementedError

    ## DATALOADER
    train_dl = DataLoader(dataset=train_ds,
                          batch_size=opt.batchsize,
                          shuffle=True,
                          drop_last=True)
    valid_dl = DataLoader(dataset=valid_ds,
                          batch_size=opt.batchsize,
                          shuffle=False,
                          drop_last=False)

    return Data(train_dl, valid_dl)
Example No. 11
 def ds_train(self):
     t = MultiSample(aug_transform(), n=self.aug_cfg.num_samples)
     return ImageFolder(root="/imagenet/train", transform=t)
Example No. 12
    def __init__(self,
                 config,
                 config_path,
                 name,
                 dataset,
                 split,
                 model,
                 pretrain: bool,
                 optimizer,
                 lr: float = DEFAULT_LR,
                 momentum: float = DEFAULT_MOMENTUM,
                 weight_decay: float = DEFAULT_WEIGHT_DECAY,
                 start_epoch: int = BaseTrainer.DEFAULT_START_EPOCH):
        super(Trainer, self).__init__(name, dataset, split, model, optimizer,
                                      start_epoch)

        if not lr > 0:
            raise ValueError(
                value_error_msg('lr', lr, 'lr > 0', Trainer.DEFAULT_LR))

        if not momentum >= 0:
            raise ValueError(
                value_error_msg('momentum', momentum, 'momentum >= 0',
                                Trainer.DEFAULT_MOMENTUM))

        if not weight_decay >= 0:
            raise ValueError(
                value_error_msg('weight_decay', weight_decay,
                                'weight_decay >= 0',
                                Trainer.DEFAULT_WEIGHT_DECAY))

        self.config = config
        self.model_path = format_path(
            self.config[self.name.value]['model_format'], self.name.value,
            self.config['Default']['delimiter'])

        if self.split == Trainer.Split.TRAIN_VAL:
            self.phase = ['train', 'val']
        elif self.split == Trainer.Split.TRAIN_ONLY:
            self.phase = ['train']
        else:
            raise ValueError(
                value_error_msg('split', split, BaseTrainer.SPLIT_LIST))

        if self.name == Trainer.Name.MARKET1501:
            transform_train_list = [
                # transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3),
                transforms.Resize((256, 128), interpolation=3),
                transforms.Pad(10),
                transforms.RandomCrop((256, 128)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ]

            transform_val_list = [
                transforms.Resize(size=(256, 128),
                                  interpolation=3),  # Image.BICUBIC
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ]

            data_transforms = {
                'train': transforms.Compose(transform_train_list),
                'val': transforms.Compose(transform_val_list),
            }
        else:
            raise ValueError(
                value_error_msg('name', self.name, Trainer.NAME_LIST))

        # dataset declaration
        self.dataset = {}
        self.dataset_sizes = {}

        # dataset loading
        for phase in self.phase:
            folder_name = phase
            if self.split == Trainer.Split.TRAIN_ONLY:
                folder_name = 'total_' + folder_name
            self.dataset[phase] = ImageFolder(
                join(self.config[self.name.value]['dataset_dir'], folder_name),
                data_transforms[phase])
            self.dataset_sizes[phase] = len(self.dataset[phase])

        # record train_class num on setting files
        model_name = self.model.value
        train_class = len(self.dataset['train'].classes)
        config[self.name.value]['train_class'] = str(train_class)
        with open(config_path, 'w+') as file:
            config.write(file)

        # initialize model weights
        if self.model == Trainer.Model.RESNET50:
            self.model = ResNet50(self.config,
                                  train_class,
                                  pretrained=pretrain)
            if self.start_epoch > 0:
                load_model(
                    self.model, self.config[self.name.value]['model_format'] %
                    (model_name, self.start_epoch))
        # else:
        #     raise ValueError(value_error_msg('model', model, Trainer.MODEL_LIST))

        self.suffix = 'pretrain' if pretrain else 'no_pretrain'
        self.train_path = self.config[
            self.name.value]['train_path'] % self.suffix

        # use different learning rates for different parameter groups in the optimizer
        ignored_params = list(map(id, self.model.final_block.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             self.model.parameters())

        if self.optimizer == BaseTrainer.Optimizer.SGD:
            self.optimizer = optim.SGD(
                [{
                    'params': base_params,
                    'lr': 0.1 * lr
                }, {
                    'params': self.model.final_block.parameters(),
                    'lr': lr
                }],
                weight_decay=weight_decay,
                momentum=momentum,
                nesterov=True)
        # else:
        #     raise ValueError(value_error_msg('optimizer', optimizer, Trainer.OPTIMIZER_LIST))
        self.criterion = nn.CrossEntropyLoss()
ckpt = torch.load('./models/15.pth')

net_ae = torch.nn.DataParallel(net_ae)
net_ig = torch.nn.DataParallel(net_ig)

net_ae.load_state_dict(ckpt['ae'])
net_ig.load_state_dict(ckpt['ig'])

net_ae.to(device)
net_ig.to(device)

net_ae.eval()

batch_size = 8
dataset = ImageFolder(root='../artland_1/data/rgb_select/',
                      transform=trans_maker(size=512))
dataset_pr = ImageFolder(root='../../data/unsplash/',
                         transform=trans_maker(size=512))

dataloader = iter(DataLoader(dataset, batch_size, \
        sampler=InfiniteSamplerWrapper(dataset), num_workers=4, pin_memory=True))

dataloader_pr = iter(DataLoader(dataset_pr, batch_size, \
        sampler=InfiniteSamplerWrapper(dataset_pr), num_workers=4, pin_memory=True))

for k in range(20):
    rgb_images = next(dataloader_pr)[0].to(device)

    skt_org_imgs = rgb_images[batch_size // 2:].clone()

    skt_imgs = net_2skt(vgg(F.interpolate(skt_org_imgs, 256), base=8)[2])
Example No. 14
def main():
    print("hello")
    args = parse_args()
    data_dir = args.data_dir
    train_dir = data_dir + '/train'
    val_dir = data_dir + '/valid'
    test_dir = data_dir + '/test'
    training_transforms = transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    validation_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    testing_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_datasets = [
        ImageFolder(train_dir, transform=training_transforms),
        ImageFolder(val_dir, transform=validation_transforms),
        ImageFolder(test_dir, transform=testing_transforms)
    ]
    dataloaders = [
        torch.utils.data.DataLoader(image_datasets[0],
                                    batch_size=64,
                                    shuffle=True),
        torch.utils.data.DataLoader(image_datasets[1],
                                    batch_size=64,
                                    shuffle=True),
        torch.utils.data.DataLoader(image_datasets[2],
                                    batch_size=64,
                                    shuffle=True)
    ]
    model = getattr(models, args.arch)(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    if args.arch == "vgg13":
        feature_num = model.classifier[0].in_features
        classifier = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(feature_num, 1024)),
                         ('drop', nn.Dropout(p=0.5)), ('relu', nn.ReLU()),
                         ('fc2', nn.Linear(1024, 102)),
                         ('output', nn.LogSoftmax(dim=1))]))
    elif args.arch == "densenet121":
        classifier = nn.Sequential(
            OrderedDict([('fc1', nn.Linear(1024, 500)),
                         ('drop', nn.Dropout(p=0.6)), ('relu', nn.ReLU()),
                         ('fc2', nn.Linear(500, 102)),
                         ('output', nn.LogSoftmax(dim=1))]))

    model.classifier = classifier
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.classifier.parameters(),
                           lr=float(args.learning_rate))
    epochs = int(args.epochs)
    class_index = image_datasets[0].class_to_idx
    gpu = args.gpu  # get the gpu settings
    train(model, criterion, optimizer, dataloaders, epochs, gpu)
    model.class_to_idx = class_index
    path = args.save_dir  # get the new save location
    save_checkpoint(path, model, optimizer, args, classifier)
def main():
    '''
    # Randomly fetch a batch of training images.
    dataiter = iter(trainloader)
    images, labels = next(dataiter)
    # Show the images.
    imshow(make_grid(images, nrow = 10))
    # Print the ground-truth labels.
    print(' '.join('%5s' % classes[labels[j]] for j in range(10)))
    '''

    # Data preprocessing.
    # Resize to 128 x 128 before normalization.
    # Most of the edges of the architecture photos are irrelevant elements (sky, clouds, etc.), so CenterCrop is used to sample a 100 x 100 region.
    trans = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.CenterCrop(100),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load the images.
    trainset = ImageFolder(root=TRAIN_PATH, transform=trans)
    testset = ImageFolder(root=TEST_PATH, transform=trans)

    # Define the loaders.
    # The data is also labeled here.
    # num_workers is set to 3, roughly half the number of cores, which is usually enough to use system resources comfortably.
    # Since there is not much data, a large batch_size is used (batch_size=50).
    # pin_memory is enabled for speed.
    trainloader = DataLoader(trainset,
                             batch_size=batch,
                             num_workers=3,
                             pin_memory=True,
                             shuffle=True)
    testloader = DataLoader(testset,
                            batch_size=batch,
                            num_workers=3,
                            shuffle=False)

    # Class labels.
    classes = ('Romanesque', 'Gothic', 'Renaissance', 'Baroque')

    # Select the network.
    net = My_Net()

    # Move the network to the CUDA device.
    net.to(device)

    # Define the loss function and optimizer (cross-entropy loss and SGD with momentum).
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(epoch_total):  # loop over the dataset several times (10 epochs)

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 5 == 4:  # print every 5 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 5))
                running_loss = 0.0

    print('Finished Training')

    # Save the trained weights (SAVE_PATH is defined at the top of the script).
    torch.save(net.state_dict(), SAVE_PATH)

    # ---------------------------- Evaluate on the test data ------------------------------
    # Load a batch of test images.
    dataiter = iter(testloader)
    images, labels = next(dataiter)
    '''
    # Show the test images.
    plt.imshow(make_grid(images))
    print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    '''

    # Load the saved weights.
    net = My_Net()
    net.load_state_dict(torch.load(SAVE_PATH))

    # Print the network's predictions for this batch.
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ',
          ' '.join('%5s' % classes[predicted[j]] for j in range(10)))

    # Accuracy over the whole test set.
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 1000 test images: %d %%' %
          (100 * correct / total))

    # Per-class accuracy.
    class_correct = list(0. for i in range(4))
    class_total = list(0. for i in range(4))
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(batch):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    for i in range(4):
        print('Accuracy of %5s : %2d %%' %
              (classes[i], 100 * class_correct[i] / class_total[i]))
Example No. 16
#from helpers.configuration_container import ConfigurationContainer
from torchvision.datasets import ImageFolder

from torchvision.transforms import ToTensor, Compose, Resize, Grayscale, Normalize
from torch.utils.data import Dataset
#from data.data_loader import DataLoader
from torchvision.utils import save_image

from PIL import Image
import torch
from torch.autograd import Variable

from imblearn.over_sampling import SMOTE

#dataset = ImageFolder(root="data/datasets/base_dir/train_dir", transform=Compose(transforms))
dataset = ImageFolder(root="./datasets/base_dir/train_dir",
                      transform=Compose([]))

print(dataset)
print(len(dataset))

mylabels = {}
totallabels = 109
tensor_list = []
labels_list = []
for img in dataset:
    tensor_list.append(img[0])
    labels_list.append(img[1])
    if not img[1] in mylabels:
        mylabels[img[1]] = 0
    if mylabels[img[1]] < totallabels:
        tensor_list.append(img[0])
Example No. 17
    def __init__(self, batch_size, transform, mode='train'):

        self.sub_meta = {}
        print("[SetDataset] mode:%s" % (mode))

        if mode == 'train':
            self.cl_list = range(64)  #64-classes
            d = ImageFolder(miniImageNet_path)
        elif mode == 'val':
            self.cl_list = range(16)  #16-classes
            d = ImageFolder(miniImageNet_val_path)
        else:
            self.cl_list = range(20)  #20-classes
            d = ImageFolder(miniImageNet_test_path)

        for cl in self.cl_list:
            self.sub_meta[cl] = []

        #=====================

        flag = False
        """
        for i in os.listdir("."):
            if mode+"_loader.npy" == i and mode == "train":
                # [FOR DEBUG PURPOSE]
                flag = True
                self.sub_meta = np.load(i,allow_pickle='TRUE').item()
                print("load dataset from %s"%(i))
        """
        if not flag:
            for i, (data, label) in enumerate(d):  # this loop needs to be accelerated!
                self.sub_meta[label].append(data)  # label-to-data dict
            """
            if mode == "train":
                for key, item in self.sub_meta.items(): self.sub_meta[key] = self.sub_meta[key][:150]
                np.save(mode+'_loader.npy', self.sub_meta)
                print("save to %s"%(mode+'_loader.npy'))
            """
        #=====================
        """
        for i, (data, label) in enumerate(d):#this line needs to be accelerated!
            self.sub_meta[label].append(data)#label2data dict
        """

        for key, item in self.sub_meta.items():  # classes 0..N-1
            print(len(self.sub_meta[key]))  # number of items stored for this class (typically prints "600" per class)

        self.sub_dataloader = []
        sub_data_loader_params = dict(
            batch_size=batch_size,
            shuffle=True,
            num_workers=0,  #use main thread only or may receive multiple batches
            pin_memory=False)

        for cl in self.cl_list:
            sub_dataset = SubDataset(self.sub_meta[cl],
                                     cl,
                                     transform=transform)
            self.sub_dataloader.append(
                torch.utils.data.DataLoader(sub_dataset,
                                            **sub_data_loader_params))
from torchvision.models import resnet152
from urllib.request import urlopen

path = "data/boat/train"

# the number of images that will be processed in a single step
batch_size = 128
# the size of the images that we'll learn on - we'll shrink them from the original size for speed
image_size = (30, 100)

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()  # convert to tensor
])

train_dataset = ImageFolder(path, transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = ImageFolder(path, transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

test_dataset = ImageFolder(path, transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

training_features = np.load('Resnet50Features/training_features.npy')
training_labels = np.load('Resnet50Features/training_labels.npy')

valid_features = np.load('Resnet50Features/valid_features.npy')
valid_labels = np.load('Resnet50Features/valid_labels.npy')

testing_features = np.load('Resnet50Features/testing_features.npy')
Example No. 19
 def create_dataset(self, root, transform):
     return ImageFolder(root=root, transform=transform)
Example No. 20
    model = torch.load(MODEL_PATH)
    model.to(DEVICE)

SIZE = 32
resize = Resize((SIZE, SIZE))
# norm = Normalize((0.5063, 0.5063, 0.5063), (0.2412, 0.2412, 0.2412))
# norm = Normalize((0.43767047, 0.44375867, 0.47279018), (0.19798356, 0.20096427, 0.19697163))
norm = Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))
to_tensor = ToTensor()

true_labels = False
data_folder = 'dataset_gan183'
dataset_folder = f'data/mnist.vgg16/{data_folder}'
stolen_labels_folder = f'data/mnist.vgg16/{data_folder}_sl{"_ml" if true_labels else ""}'

dataset = ImageFolder(root=dataset_folder)

if not os.path.exists(stolen_labels_folder):
    os.mkdir(stolen_labels_folder)

real_labels = np.array([])
pred_labels = np.array([])

print('Generating labels from target...')
with torch.no_grad():
    model.eval()

    for image, label in tqdm(dataset):
        real_labels = np.append(real_labels, label)
        # image = image.to(DEVICE)
        image = resize(image)
Example No. 21
    args.center_loss_weight)

# create logger
logger = Logger(LOG_DIR)

kwargs = {'num_workers': 2, 'pin_memory': True} if args.cuda else {}
l2_dist = PairwiseDistance(2)

transform = transforms.Compose([
    transforms.Scale(96),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dir = ImageFolder(args.dataroot, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dir,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)

test_loader = torch.utils.data.DataLoader(LFWDataset(
    dir=args.lfw_dir, pairs_path=args.lfw_pairs_path, transform=transform),
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          **kwargs)


def main():
    test_display_triplet_distance = True
    # print the experiment configuration
Example No. 22
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import alexnet
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
import torch

if __name__ == "__main__":
    np.random.seed(65)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trans = transforms.Compose([transforms.ToTensor(), normalize])
    #dir_path = './hw2-4_data/problem2/'
    dir_path = sys.argv[1]
    batch_size = 1
    train_path = os.path.join(dir_path, 'train/')
    train_set = ImageFolder(train_path, transform=trans)
    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    valid_path = os.path.join(dir_path, 'valid/')
    valid_set = ImageFolder(valid_path, transform=trans)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_set, batch_size=batch_size, shuffle=False)

    device = torch.device('cuda')
    extractor = alexnet(pretrained=True).features
    extractor.to(device)
    extractor.eval()

    train_feature, train_label = [], []
    with torch.no_grad():
        for img, label in train_loader:
            img, label = img.to(device), label.to(device)
            feat = extractor(img).view(img.size(0), 256, -1)
Example No. 23
def main():
    # Parse arguments
    parser = ArgumentParser()
    parser.add_argument("--imsize",
                        dest="imsize",
                        help="size of images used to train the CNN")
    parser.add_argument(
        "-o",
        "--out",
        dest="out_dir",
        help="location to which trained model will be exported")
    parser.add_argument("-e",
                        "--epochs",
                        dest="num_epochs",
                        help="number of epochs")
    args = parser.parse_args()
    imsize = int(args.imsize)
    out_dir = args.out_dir
    num_epochs = int(args.num_epochs)
    print(f'Imsize: {imsize}')
    print(f'Num epochs: {num_epochs}')

    json_path = path.join(out_dir, 'details.json')
    model_path = path.join(out_dir, 'model')

    # Fixed seed for reproducibility of results
    torch.manual_seed(42)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}')

    # Load dataset
    full_dataset = ImageFolder(DATA_DIR)
    class_names = full_dataset.classes

    # Random split into training and testing dataset
    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_subset, test_subset = random_split(full_dataset,
                                             [train_size, test_size])

    # Compute mean and std of training data
    train_copy = copy.copy(
        train_subset
    )  # Copy training dataset to prevent in-place modifications
    train_dataset = DatasetFromSubset(train_copy,
                                      transform=transforms.Compose([
                                          transforms.Resize((imsize, imsize)),
                                          transforms.ToTensor()
                                      ]))
    data_loader = DataLoader(train_dataset,
                             batch_size=4,
                             shuffle=False,
                             num_workers=8)
    train_mean, train_std = get_mean_std(data_loader)
    print(f'Training dataset:')
    print(f'\tmean: {train_mean.tolist()}')
    print(f'\tstd:  {train_std.tolist()}')

    json_data = {
        'train_mean': train_mean.tolist(),
        'train_std': train_std.tolist(),
        'imsize': imsize,
        'class_names': class_names
    }

    # Source: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    # Data augmentation and normalization for training
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomResizedCrop(int(0.8 * imsize)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=train_mean, std=train_std)
        ]),
        'val':
        transforms.Compose([
            transforms.Resize(imsize),
            transforms.CenterCrop(int(0.8 * imsize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=train_mean, std=train_std)
        ]),
    }

    train_dataset = DatasetFromSubset(train_subset,
                                      transform=data_transforms['train'])
    val_dataset = DatasetFromSubset(test_subset,
                                    transform=data_transforms['val'])
    dataloaders = {
        'train':
        DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=8),
        'val':
        DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=8)
    }
    dataset_sizes = {'train': len(train_subset), 'val': len(test_subset)}

    # Load pre-trained model
    model_ft = models.resnet18(pretrained=True)

    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, len(class_names))
    model_ft = model_ft.to(device)
    criterion = nn.CrossEntropyLoss()
    # Observe that all parameters are being optimized
    optimizer_ft = torch.optim.SGD(model_ft.parameters(),
                                   lr=0.001,
                                   momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_ft,
                                                       step_size=7,
                                                       gamma=0.1)

    # Train model
    model_ft, accuracy = train_model(model_ft,
                                     criterion,
                                     optimizer_ft,
                                     exp_lr_scheduler,
                                     dataloaders,
                                     dataset_sizes,
                                     device=device,
                                     num_epochs=num_epochs)

    json_data['accuracy'] = accuracy.item()
    # Store data to JSON for API
    with open(json_path, 'w') as out:
        json.dump(json_data, out)

    # Save trained model
    torch.save(model_ft.state_dict(), model_path)
Example No. 24
preprocess_resnet = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

if 'vgg' in args.model:
    preprocess = preprocess_vgg
elif 'resnet' in args.model:
    preprocess = preprocess_resnet
elif 'densenet' in args.model:
    preprocess = preprocess_densenet

trainDataPath = os.getcwd() + "/GroceryStoreDataset-master/dataset/train/Packages"
train_dataset = ImageFolder(root=trainDataPath, transform=preprocess)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.train_batch_size,
                                           shuffle=True)
testDataPath = os.getcwd() + "/GroceryStoreDataset-master/dataset/test/Packages"
test_dataset = ImageFolder(root=testDataPath,
                           transform=preprocess)  # should this be applied at test time?
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.test_batch_size,
                                          shuffle=False)
validDataPath = os.getcwd() + "/GroceryStoreDataset-master/dataset/valid/Packages"
valid_dataset = ImageFolder(root=validDataPath, transform=preprocess)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=args.test_batch_size,
Example No. 25
        self.hook.remove()


is_cuda = False
if torch.cuda.is_available():
    is_cuda = True

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train = ImageFolder('digit_class_train/', train_transform)
valid = ImageFolder('digit_class_valid/', train_transform)

train_data_loader = torch.utils.data.DataLoader(train,
                                                batch_size=16,
                                                shuffle=True)
valid_data_loader = torch.utils.data.DataLoader(valid,
                                                batch_size=16,
                                                shuffle=True)

vgg = models.vgg16(pretrained=True)
vgg = vgg.cuda()

# replace the final classifier layer so the network outputs 10 classes
# (setting .out_features alone would not resize the layer's weights)
vgg.classifier[6] = torch.nn.Linear(vgg.classifier[6].in_features, 10)
for param in vgg.features.parameters():
    param.requires_grad = False  # freeze the feature layers so their gradients are not updated
Example No. 26
 def load_dataset(self):
     self.data = ImageFolder(root=self.data_dir)
Example No. 27
def count_images(img_path):
    return len(glob(img_path + '/**/*.jpg'))


train_count = count_images(train_imgs)
test_count = count_images(test_imgs)

print('Image count for training [{}]'.format(train_count))
print('Image count for testing [{}]'.format(test_count))

# image formatter
train_transformer = generate_transformer(include_flip=True)
test_transformer = generate_transformer()

# data sets
train_dataset = ImageFolder(train_imgs, transform=train_transformer)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

test_dataset = ImageFolder(test_imgs, transform=test_transformer)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# assuming that both training and testing classes are the same
classes = train_dataset.classes
print('Classes [{}]'.format(classes))

# training the model will update the model instance
# in place (by reference)
train_model(num_epochs, model, optimizer, loss_function, train_count,
            test_count, train_loader, test_loader)

print('Model {}'.format(model))
Example No. 28
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, h):
        y = self.layer(h)
        return y


source_dataset = ImageFolder('../real_or_drawing/train_data',
                             transform=source_transform)
target_dataset = ImageFolder('../real_or_drawing/test_data',
                             transform=target_transform)

source_dataloader = DataLoader(source_dataset, batch_size=32, shuffle=True)
target_dataloader = DataLoader(target_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(target_dataset, batch_size=128, shuffle=False)

feature_extractor = FeatureExtractor().cuda()
label_predictor = LabelPredictor().cuda()

feature_extractor.load_state_dict(torch.load('extractor_model.bin'))
label_predictor.load_state_dict(torch.load('predictor_model.bin'))

class_criterion = nn.CrossEntropyLoss()
Example No. 29
# data preprocessing
data_transform = transforms.Compose([
    #transforms.Scale((224,224), 2),     # resize images to a uniform size
    transforms.Resize([224, 224], 2),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # normalize the image
                         std=[0.229, 0.224, 0.225])
])

# load the datasets
# training set
train_dataset = ImageFolder(root='work/data/train/', transform=data_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# validation set
val_dataset = ImageFolder(root='work/data/val/', transform=data_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# classes
data_classes = train_dataset.classes

# select the model
net = models.alexnet()
net.classifier = nn.Sequential(
    nn.Dropout(),
    nn.Linear(256 * 6 * 6, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(),
def main(args):
    dtype = torch.FloatTensor
    if args.use_gpu:
        dtype = torch.cuda.FloatTensor

    train_transform = T.Compose([
        T.Scale(256),
        T.RandomSizedCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    # you can read more about the ImageFolder class here:
    # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
    train_dset = ImageFolder(args.train_dir, transform=train_transform)
    train_loader = DataLoader(train_dset,
                              batch_size=args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True)

    val_transform = T.Compose([
        T.Scale(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    val_dset = ImageFolder(args.val_dir, transform=val_transform)
    val_loader = DataLoader(val_dset,
                            batch_size=args.batch_size,
                            num_workers=args.num_workers,
                            shuffle=True)

    # First load the pretrained ResNet-18 model; this will download the model
    # weights from the web the first time you run it
    model = resnet18(pretrained=True)

    # Reinitialize the last layer of the model.
    # Each pretrained model has a slightly different structure, but from the Resnet class definition
    # we see that the final fully-connected layer is stored in model.fc:
    num_classes = len(train_dset.classes)
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Cast the model to the correct datatype, and create a loss function for training the model
    model.type(dtype)
    loss_fn = nn.CrossEntropyLoss().cuda()

    # First we want to train only the reinitialized last layer for a few epochs.
    # During this phase we do not need to compute gradients with respect to the
    # other weights of the model, so we set the requires_grad flag to False for
    # all model parameters, then set the requires_grad=True for the parameters in the
    # last layer only.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True

    # Construct an Optimizer object for updating the last layer only.
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-5)

    # Train the entire model for a few more epochs, checking accuracy on the
    # train and validation sets after each epoch.
    for epoch in range(args.num_epochs2):
        print('Starting epoch %d / %d' % (epoch + 1, args.num_epochs2))
        run_epoch(model, loss_fn, train_loader, optimizer, dtype)

        train_acc = check_accuracy(model, train_loader, dtype)
        val_acc = check_accuracy(model, val_loader, dtype)

        print("train_Acc", train_acc)
        print("val_acc", val_acc)
        print("\n")