Example #1
 def __init__(self, data_root, effect):
     self.data_root = os.path.join(data_root, effect)
     self.data_figure_root = os.path.join(data_root, effect + '-figure')
     self.effect = effect
     self.data_list = sorted(os.listdir(self.data_root), key=lambda x: int(x.replace('.png', '')))
     self.data_figure_list = sorted(os.listdir(self.data_figure_root), key=lambda x: int(x.replace('.png', '')))
     self.effect_shape = cv2.imread(os.path.join(self.data_root, self.data_list[0])).shape[:2]
     self.figure_shape = cv2.imread(os.path.join(self.data_figure_root, self.data_figure_list[0])).shape[:2]
     # self.scaled_shape = (int(self.train_shape[0] * 0.5), int(self.train_shape[1] * 0.8))
     diff_h = (self.figure_shape[0] - self.effect_shape[0]) // 2
     diff_w = (self.figure_shape[1] - self.effect_shape[1]) // 2
     self.target_transforms = transforms.Compose([
         transforms.ToPILImage(),
         transforms.RandomAffine(degrees=(-20, 20), translate=(0.2, 0.3), scale=(0.5, 1.1), ),
     ])
     self.source_transforms = transforms.Compose([
         transforms.ToPILImage(),
         transforms.Pad((diff_w, diff_h, diff_w, diff_h)),  # left, top, right, bottom
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
     self.final_transforms = transforms.Compose([
         transforms.ToPILImage(),
         transforms.RandomAffine(degrees=(-20, 20), translate=(0.2, 0.5), scale=(0.5, 1.1), ),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
     self.length = len(self.data_list) - 1
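Here Pad centers each (smaller) effect frame inside the larger figure frame by splitting the size difference across opposite sides. A minimal sketch of that behavior, with assumed shapes:

import torch
from torchvision import transforms

figure_shape = (512, 512)   # (H, W), assumed for illustration
effect_shape = (384, 448)

diff_h = (figure_shape[0] - effect_shape[0]) // 2
diff_w = (figure_shape[1] - effect_shape[1]) // 2

pad = transforms.Pad((diff_w, diff_h, diff_w, diff_h))  # left, top, right, bottom
effect = torch.zeros(3, *effect_shape)
print(pad(effect).shape)  # torch.Size([3, 512, 512]) when the difference is even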
Example #2
    def __getitem__(self, index):
        image_path, labels = self.imgs[index]
        em_label = np.zeros(109)
        em_label[labels] = 1  # single-label encoding
        # em_label = labels

        img = Image.open(image_path)
        img = img.convert("RGB")
        max_side_dest_length = self.input_size
        max_side_length = max(img.size)
        ratio = max_side_dest_length / max_side_length
        new_size = [int(ratio * x) for x in img.size[::-1]]  # (h, w), aspect ratio preserved
        new_size = [299, 299]  # NOTE: this hard-coded size overrides the aspect-preserving resize above

        h = new_size[0]
        w = new_size[1]
        pad = [0, 0]

        if h > w:
            pad = [(h - w), 0, 0, 0]
        else:
            pad = [0, (w - h), 0, 0]

        img_loader = transforms.Compose([
            transforms.Resize(new_size),
            transforms.Pad(padding=tuple(pad)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        image = img_loader(img)
        return image, torch.LongTensor(em_label)  # LongTensor
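For reference, a sketch of the letterboxing this method appears to intend (the hard-coded new_size = [299, 299] above makes the computed ratio and padding no-ops): resize so the longest side hits the target, then pad the short side out to a square.

from PIL import Image
from torchvision import transforms

def letterbox(img: Image.Image, target: int = 299):
    ratio = target / max(img.size)
    w, h = (int(ratio * s) for s in img.size)            # PIL size is (w, h)
    return transforms.Compose([
        transforms.Resize((h, w)),                       # Resize takes (h, w)
        transforms.Pad((0, 0, target - w, target - h)),  # left, top, right, bottom
        transforms.ToTensor(),
    ])(img)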
Example #3
    def __call__(self, params):
        sample, coords = params
        width, height = sample.size
        width, height = cut_down(sample, width, height)

        x1, y1, x2, y2 = coords
        x = x2 - x1
        y = y2 - y1

        padding = (0, 0)
        pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0

        if x > y:
            y2_ = int((y1 + y2 + x) / 2.)
            y1_ = int((y1 + y2 - x) / 2.)
            x1_ = x1
            x2_ = x2

        else:  # x <= y
            x2_ = int((x1 + x2 + y) / 2.)
            x1_ = int((x1 + x2 - y) / 2.)
            y1_ = y1
            y2_ = y2

        if y1_ < 0:
            pad_top = -y1_
            y1_ = 0
        if y2_ > height:
            pad_bottom = y2_ - height
            y2_ = height
        if x1_ < 0:
            pad_left = -x1_
            x1_ = 0
        if x2_ > width:
            pad_right = x2_ - width
            x2_ = width

        # Pad with a 2-tuple pads left/right by the first value and
        # top/bottom by the second, so the one-sided pads are averaged here.
        padding = (int((pad_left + pad_right) / 2.),
                   int((pad_bottom + pad_top) / 2.))
        coords = (x1_, y1_, x2_, y2_)
        """
        print(padding)
        print(coords)
        print((x2_ - x1_, y2_ - y1_))
        """

        sample = sample.crop(coords)

        if sum(padding) > 0:
            # tuple(ImageStat.Stat(sample).mean)
            pad = transforms.Pad(
                padding, fill=(100, 100, 100), padding_mode='constant'
            )  # padding_mode IN ['constant', 'edge', 'reflect', 'symmetric']
            sample = pad(sample)

        #sample.save("/model/test_square/img/" + str(randint(0,1e6)) + ".jpg")
        return sample
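A hypothetical use of this callable; the owning class and the cut_down helper are not shown in the excerpt, so the names below are assumed:

from PIL import Image

square_crop = SquareCrop()  # assumed class name for the __call__ above
img = Image.open("person.jpg")
patch = square_crop((img, (120, 40, 280, 360)))  # (x1, y1, x2, y2) box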
Example #4
 def __init__(self, model):
     self.model = model
     self.device = next(model.parameters()).device
     self.transform = transforms.Compose([
         transforms.Pad(4),
         transforms.RandomHorizontalFlip(),
         transforms.RandomCrop(32),
         transforms.ToTensor()
     ])
     self.stat = Counter()
     self.path = "./data/models_dict/%s.ckpt" % self.model.__class__.__name__
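Pad(4) followed by RandomCrop(32) is the standard CIFAR-10 shift augmentation: the 32x32 image is zero-padded to 40x40 and a random 32x32 window is cut back out, translating the content by up to 4 pixels in each direction. A quick check:

import torch
from torchvision import transforms

aug = transforms.Compose([transforms.Pad(4), transforms.RandomCrop(32)])
x = torch.rand(3, 32, 32)
print(aug(x).shape)  # torch.Size([3, 32, 32]), content shifted by up to +/-4 px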
Example #5
    def setup(self, stage: Optional[str] = None) -> None:
        base_aug = [transforms.ToTensor()]
        base_aug.insert(0, transforms.Pad(2))
        train_data = MNIST(root=self.data_dir, download=False, train=True)
        test_data = MNIST(root=self.data_dir, download=False, train=False)

        self._filter(train_data)
        self._filter(test_data)

        train_len = int(len(train_data))
        num_train, _ = self._get_splits(train_len, self.val_split)

        g_cpu = torch.Generator()
        g_cpu = g_cpu.manual_seed(self.seed)
        train_data, val_data = random_split(
            train_data,
            lengths=(num_train, train_len - num_train),
            generator=g_cpu,
        )

        colorizer = LdColorizer(
            scale=self.scale,
            background=False,
            black=True,
            binarize=False,
            greyscale=False,
            color_indices=self.colours,
        )

        self._train_data = LdAugmentedDataset(
            train_data,
            ld_augmentations=colorizer,
            num_classes=self.num_classes,
            num_colours=self.num_colours,
            li_augmentation=False,
            base_augmentations=base_aug,
            correlation=self.correlation,
        )
        self._val_data = LdAugmentedDataset(
            val_data,
            ld_augmentations=colorizer,
            num_classes=self.num_classes,
            num_colours=self.num_colours,
            li_augmentation=True,
            base_augmentations=base_aug,
        )
        self._test_data = LdAugmentedDataset(
            test_data,
            ld_augmentations=colorizer,
            num_classes=self.num_classes,
            num_colours=self.num_colours,
            li_augmentation=True,
            base_augmentations=base_aug,
        )
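Pad(2) in base_aug grows the 28x28 MNIST digits to 32x32 before ToTensor, a common trick for matching CNN architectures designed around 32x32 inputs. A minimal sketch of that effect:

from PIL import Image
from torchvision import transforms

tf = transforms.Compose([transforms.Pad(2), transforms.ToTensor()])
digit = Image.new("L", (28, 28))  # stand-in for an MNIST image
print(tf(digit).shape)            # torch.Size([1, 32, 32])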
Example #6
    def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
        # x, _ = self.normalize_fn(x, x)
        # if mask is None:
        #     mask = np.ones_like(x, dtype=np.float32)
        # else:
        #     mask = np.round(mask.astype('float32') / 255)

        _, h, w = x.shape
        block_size = 32
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        # pad_params = {'mode': 'constant',
        #               'constant_values': 0,
        #               'pad_width': ((0, 0), (0, min_height - h), (0, min_width - w))
        #               }

        # x = np.pad(x, **pad_params)
        # mask = np.pad(mask, **pad_params)
        x = torch.from_numpy(x)  # transforms.Pad expects a tensor or PIL image
        # Pad takes (left, top, right, bottom): pad right by the width
        # difference and bottom by the height difference (the original call
        # had the two deltas swapped).
        x = transforms.Pad((0, 0, min_width - w, min_height - h))(x)
        x = torch.unsqueeze(x, 0)
        mask = torch.ones_like(x)
        return (x, mask), h, w
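A quick check, with assumed input sizes, that the pad rounds both dimensions up to the next multiple of the 32-pixel block (note the formula always adds a full block, even when the size is already a multiple of 32):

import torch
from torchvision import transforms

h, w, block = 500, 333, 32
min_h = (h // block + 1) * block   # 512
min_w = (w // block + 1) * block   # 352
x = torch.rand(3, h, w)
y = transforms.Pad((0, 0, min_w - w, min_h - h))(x)  # pad right/bottom only
print(y.shape)  # torch.Size([3, 512, 352])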
Example #7
    def __init__(self,
                 size=(256, 128),
                 random_horizontal_flip=0,
                 pad=0,
                 normalize=True,
                 random_erase=0):
        """
        :param size:
        :param random_horizontal_flip: strong baseline = 0.5
        :param pad: strong baseline = 10
        :param normalize:
        :param random_erase: strong baseline = 0.5
        """
        transforms_list = list()
        transforms_list.append(transforms.Resize(size))

        if random_horizontal_flip:
            transforms_list.append(
                transforms.RandomHorizontalFlip(random_horizontal_flip))

        if pad:
            transforms_list.append(transforms.Pad(pad))
            transforms_list.append(transforms.RandomCrop(size))

        transforms_list.append(transforms.ToTensor())

        if normalize:
            transforms_list.append(self.normalization)

        if random_erase:
            transforms_list.append(
                transforms.RandomErasing(random_erase, self.random_erase_scale,
                                         self.random_erase_ratio,
                                         self.random_erase_value))

        super().__init__(transforms_list)
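A hypothetical instantiation of this pipeline; the class name and its class-level normalization / random_erase_* attributes are not shown in the excerpt, so the names below are assumed. The docstring's "strong baseline" values are used:

from PIL import Image

pipeline = ReIDTransforms(size=(256, 128),   # assumed class name
                          random_horizontal_flip=0.5,
                          pad=10,
                          random_erase=0.5)
tensor = pipeline(Image.open("person.jpg"))  # behaves like a transforms.Compose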
Example #8
def create_cmnist_datasets(
    *,
    root: str,
    scale: float,
    train_pcnt: float,
    download: bool = False,
    seed: int = 42,
    rotate_data: bool = False,
    shift_data: bool = False,
    padding: int = 0,
    quant_level: int = 8,
    input_noise: bool = False,
    classes_to_keep: Optional[Sequence[_Classes]] = None,
) -> Tuple[LdTransformedDataset, LdTransformedDataset]:
    """Create and return colourised MNIST train/test pair.

    Args:
        root: Where the images are downloaded to.
        scale: The amount of 'bias' in the colour. Lower is more biased.
        train_pcnt: The percentage of the data to use for the training set.
        download: Whether or not to download the data.
        seed: Random seed for reproducing results.
        rotate_data: Whether or not to rotate the training images.
        shift_data: Whether or not to shift the training images.
        padding: Amount of padding to add around the training images (0 disables).
        quant_level: The number of bins to quantize the data into.
        input_noise: Whether or not to add noise to the training images.
        classes_to_keep: Which digit classes to keep. If None or empty then all classes will be kept.

    Returns:
        tuple of train and test data as a Dataset.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    base_aug = [transforms.ToTensor()]
    data_aug = []
    if rotate_data:
        data_aug.append(transforms.RandomAffine(degrees=15))
    if shift_data:
        data_aug.append(
            transforms.RandomAffine(degrees=0, translate=(0.11, 0.11)))
    if padding > 0:
        base_aug.insert(0, transforms.Pad(padding))
    if quant_level != 8:
        base_aug.append(Quantize(int(quant_level)))
    if input_noise:
        base_aug.append(NoisyDequantize(int(quant_level)))

    mnist_train = MNIST(root=root, train=True, download=download)
    mnist_test = MNIST(root=root, train=False, download=download)

    if classes_to_keep:
        mnist_train = _filter_classes(dataset=mnist_train,
                                      classes_to_keep=classes_to_keep)
        mnist_test = _filter_classes(dataset=mnist_test,
                                     classes_to_keep=classes_to_keep)

    all_data: ConcatDataset = ConcatDataset([mnist_train, mnist_test])
    train_data, test_data = train_test_split(all_data, train_pcnt=train_pcnt)

    colorizer = LdColorizer(scale=scale,
                            background=False,
                            black=True,
                            binarize=True,
                            greyscale=False)
    train_data = DatasetWrapper(train_data, transform=base_aug + data_aug)
    train_data = LdTransformedDataset(
        dataset=train_data,
        ld_transform=colorizer,
        target_dim=10,
        label_independent=False,
        discrete_labels=True,
    )
    test_data = DatasetWrapper(test_data, transform=base_aug)
    test_data = LdTransformedDataset(
        test_data,
        ld_transform=colorizer,
        target_dim=10,
        label_independent=True,
        discrete_labels=True,
    )

    return train_data, test_data
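A sketch of calling the factory above; the scale and split values here are illustrative assumptions:

train_data, test_data = create_cmnist_datasets(
    root="./data",
    scale=0.02,        # lower = more colour bias, per the docstring
    train_pcnt=0.8,
    download=True,
    padding=2,         # pads each MNIST digit from 28x28 to 32x32
)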
Example #9
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms

from backpack import backpack, extend
from benchmark_networks import net_cifar100_allcnnc, net_cifar10_3c3d, net_cifar10_3c3d_small, net_fmnist_2c2d

"""
Data loading
"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

cifar_transform = transforms.Compose([
    transforms.Pad(padding=2),
    transforms.RandomCrop(size=(32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=63. / 255.,
                           saturation=[0.5, 1.5],
                           contrast=[0.2, 1.8]),
    transforms.ToTensor(),
    transforms.Normalize((0.49139968, 0.48215841, 0.44653091),
                         (0.24703223, 0.24348513, 0.26158784))
])


def make_loader_for_dataset(dataset):
    def loader(batch_size):
        return DataLoader(
            dataset,
Example #10
 def Pad(self, **kwargs):
     return self._add(transforms.Pad(**kwargs))
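This looks like part of a fluent builder in which each method wraps the matching torchvision transform. A hypothetical chained use, assuming _add appends the transform and returns the builder, and that RandomCrop() is a sibling method implemented like Pad():

pipeline = TransformBuilder()  # assumed builder class name
pipeline.Pad(padding=4).RandomCrop(size=32)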
Example #11
def load_dataset(data_dir, resize, dataset_name, img_type):
    if dataset_name == 'cifar_10':
        mean = cifar_10['mean']
        std = cifar_10['std']
    elif dataset_name == 'cifar_100':
        mean = cifar_100['mean']
        std = cifar_100['std']
    else:
        print(
            'Dataset not recognized; normalizing with equal mean/std weights'
        )
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    hdf5_folder = '{}/hdf5'.format(data_dir)
    if os.path.exists(hdf5_folder):
        shutil.rmtree(hdf5_folder)
    create_hdf5(data_dir, resize, dataset_name, img_type)
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(resize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean=mean, std=std)])

    if isinstance(C.get()['aug'], list):
        logger.debug('augmentation provided.')
        train_transform.transforms.insert(0, Augmentation(C.get()['aug']))
    else:
        logger.debug('augmentation: %s' % C.get()['aug'])
        if C.get()['aug'] == 'random2048':
            train_transform.transforms.insert(
                0, Augmentation(random_search2048()))
        elif C.get()['aug'] == 'fa_reduced_cifar10':
            train_transform.transforms.insert(
                0, Augmentation(fa_reduced_cifar10()))
        elif C.get()['aug'] == 'fa_reduced_imagenet':
            train_transform.transforms.insert(
                0, Augmentation(fa_reduced_imagenet()))

        elif C.get()['aug'] == 'arsaug':
            train_transform.transforms.insert(0, Augmentation(arsaug_policy()))
        elif C.get()['aug'] == 'autoaug_cifar10':
            train_transform.transforms.insert(
                0, Augmentation(autoaug_paper_cifar10()))
        elif C.get()['aug'] == 'autoaug_extend':
            train_transform.transforms.insert(0,
                                              Augmentation(autoaug_policy()))
        elif C.get()['aug'] in ['default', 'inception', 'inception320']:
            pass
        else:
            raise ValueError('not found augmentations. %s' % C.get()['aug'])

    if C.get()['cutout'] > 0:
        train_transform.transforms.append(CutoutDefault(C.get()['cutout']))

    hdf5_folder = '{}/hdf5'.format(data_dir)
    hdf5_train_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name,
                                             'training')
    hdf5_test_path = '{}/{}_{}.hdf5'.format(hdf5_folder, dataset_name, 'test')
    train_dataset = CustomDataset(hdf5_file=hdf5_train_path,
                                  transform=train_transform)
    val_dataset = CustomDataset(hdf5_file=hdf5_train_path,
                                transform=test_transform)
    test_dataset = CustomDataset(hdf5_file=hdf5_test_path,
                                 transform=test_transform)

    train_dataset.train_labels = train_dataset.labels_id
    return [train_dataset, val_dataset, test_dataset]
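A hypothetical call to the loader above; the cifar_10/cifar_100 stats, create_hdf5, CustomDataset, the C config object, and the augmentation registries all come from the surrounding project and are assumed here:

# Assumes C.get()['aug'] and C.get()['cutout'] are configured elsewhere.
train_ds, val_ds, test_ds = load_dataset(data_dir='./data',
                                         resize=32,
                                         dataset_name='cifar_10',
                                         img_type='png')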
Example #12
import os
from PIL import Image
import numpy as np
import os.path as osp
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torchvision.transforms import transforms

transform_train_list = [
    #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
    transforms.Resize((256, 128), interpolation=3),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

transform_test_list = [
    transforms.Resize(size=(256, 128), interpolation=3),  #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

data_transform = {'train': transform_train_list, 'test': transform_test_list}
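These transform lists are plain Python lists rather than callable pipelines; a minimal sketch of wrapping them in transforms.Compose before handing them to a dataset:

composed = {mode: transforms.Compose(tf_list)
            for mode, tf_list in data_transform.items()}
train_tf, test_tf = composed['train'], composed['test']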


class ImageDataset(Dataset):
    def __init__(self, dataset, transformer=None):
        self.dataset = dataset
Example #13
    # Create the snapshot folder if it does not exist
    snapshot_prefix = snapshot_prefix.replace('\\', '/')
    snapshot_folder = '/'.join(snapshot_prefix.split('/')[:-1])
    if not os.path.exists(snapshot_folder):
        os.makedirs(snapshot_folder)

    # Use CIFAR-10?
    if args.cifar10 == '1':
        print('is cifar10')
        train_loader = torch.utils.data.DataLoader(
            torchvision.datasets.CIFAR10('./data',
                                         train=True,
                                         download=False,
                                         transform=transforms.Compose([
                                             transforms.Pad(4),
                                             transforms.RandomCrop(32),
                                             transforms.RandomHorizontalFlip(),
                                             transforms.ToTensor(),
                                             transforms.Normalize(
                                                 (0.4914, 0.4822, 0.4465),
                                                 (0.2023, 0.1994, 0.2010))
                                         ])),
            batch_size=batch_size,
            shuffle=True,
            num_workers=0)

        test_loader = torch.utils.data.DataLoader(torchvision.datasets.CIFAR10(
            './data',
            train=False,
            transform=transforms.Compose([
Example #14
def train(hyper_param_dict, model, device):

    transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor()
    ])

    train_dataset = datasets.CIFAR10(root=hyper_param_dict['data root'],
                                     train=True,
                                     transform=transform,
                                     download=True)

    test_dataset = datasets.CIFAR10(root=hyper_param_dict['data root'],
                                    train=False,
                                    transform=transforms.ToTensor())

    loss_function = torch.nn.CrossEntropyLoss()
    if hyper_param_dict['optimizer'] == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=hyper_param_dict['lr'],
            betas=(hyper_param_dict['beta1'], hyper_param_dict['beta2']),
            weight_decay=hyper_param_dict['weight_decay'])
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=hyper_param_dict['lr'],
            momentum=hyper_param_dict['momentum'],
            weight_decay=hyper_param_dict['weight_decay'])
    scheduler = lr_scheduler(hyper_param_dict, optimizer)

    model.train()
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=hyper_param_dict['batch'],
        shuffle=True,
        num_workers=3)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=hyper_param_dict['batch'],
        shuffle=False,
        num_workers=3)

    iters = len(train_loader)
    train_epoch = hyper_param_dict['epochs']

    ################# Make save directory
    if not os.path.isdir(hyper_param_dict['root dir']):
        os.makedirs(hyper_param_dict['root dir'])

    project_name = hyper_param_dict['project']
    now = time.localtime()
    save_time = str(now.tm_mday) + '_' + str(now.tm_hour) + '_' + str(
        now.tm_min)
    if project_name is not None:
        save_dir = os.path.join(hyper_param_dict['root dir'], project_name,
                                save_time)
    else:
        save_dir = os.path.join(hyper_param_dict['root dir'], save_time)

    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    ############ Save parameter
    with open(save_dir + '/hyper.pickle', 'wb') as fw:
        pickle.dump(hyper_param_dict, fw)



    train_acc_list, test_acc_list, train_loss_list, test_loss_list = [], [], [], []
    best = 0
    start = time.time()
    for epoch in range(train_epoch):
        # start training
        model.train()
        train_loss, train_acc, total, correct = 0, 0, 0, 0

        train_start = time.time()

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            output = model(images)
            train_loss = loss_function(output, labels)
            train_loss.backward()
            optimizer.step()

            if hyper_param_dict['lr scheduler'] == 'cos warm up':
                lr_step = epoch + i / iters
                scheduler.step(epoch=lr_step)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum().item()
            total += labels.size(0)
        # scheduler.step()

        train_end = time.time()

        train_acc = correct / total * 100.
        # print(scheduler.get_lr()[0])
        print(
            "\nEpoch [{}]\n[Train set] Loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)  Epoch time : {:.4f} "
            .format(epoch + 1, train_loss.item(), correct, total, train_acc,
                    train_end - train_start),
            end='\n')
        print(f"   LR : {scheduler.get_last_lr()[0]:.4f}")
        #save train result
        train_loss_list.append(train_loss.item() / total)  # store a float, not a graph-holding tensor
        train_acc_list.append(train_acc)

        # start evaluation
        model.eval()
        test_loss, test_correct, test_total = 0, 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                output = model(images)
                test_loss += loss_function(output, labels).item()

                pred = output.max(1, keepdim=True)[1]
                test_correct += pred.eq(labels.view_as(pred)).sum().item()

                test_total += labels.size(0)
        test_acc = 100. * test_correct / test_total

        print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.
              format(test_loss / test_total, test_correct, test_total,
                     test_acc))
        #save test result
        test_loss_list.append(test_loss / test_total)
        test_acc_list.append(test_acc)

        #save best model
        if best < test_acc:
            best = test_acc
            save_file_name = save_dir + '/' + project_name + '_best.pt'
            torch.save(model.state_dict(), save_file_name)
            print('save by ' + save_file_name)

    end = time.time()
    print("Time ellapsed in training is: {}".format(end - start))

    hyper_param_dict['training time'] = end - start
    '''
    Step 5
    '''
    # save file by numpy
    np_train_acc_list, np_test_acc_list = np.array(train_acc_list), np.array(
        test_acc_list)
    np_train_loss_list, np_test_loss_list = np.array(
        train_loss_list), np.array(test_loss_list)

    print(f'save dir : {save_dir}')
    np.save(save_dir + '/np_train_acc_list', np_train_acc_list)
    np.save(save_dir + '/np_train_loss_list', np_train_loss_list)
    np.save(save_dir + '/np_test_acc_list', np_test_acc_list)
    np.save(save_dir + '/np_test_loss_list', np_test_loss_list)
Example #15
def main():
    global dev
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    num_epochs = 80
    learning_rate = 0.001
    batch_size = 128

    # Image preprocessing modules
    # .Compose(): composes several transforms in a list together
    transform = transforms.Compose([
        transforms.Pad(padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32),
        transforms.ToTensor()
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root="./data/CIFAR10/",
        train=True,
        transform=transform,  # apply the augmentation to the training split
        download=True)
    test_dataset = torchvision.datasets.CIFAR10(
        root="./data/CIFAR10/",
        train=False,
        transform=transforms.ToTensor(),
        download=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False)

    model = ResNet(ResidualBlock, [2, 2, 2]).to(dev)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    cur_lr = learning_rate
    step = 0
    writer = SummaryWriter("./logs/resnet/")

    s_time = time.time()
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(dev)
            labels = labels.to(dev)

            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step == 0:
                writer.add_graph(model, images)

            if (step + 1) % 100 == 0:
                writer.add_scalar('loss', loss.item(), step)

                for tag, value in model.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram(tag, value.data.cpu().numpy(), step)
                    writer.add_histogram(tag,
                                         value.grad.data.cpu().numpy(), step)

                print("Epoch-Step {}/{}-{} | Loss {}".format(
                    epoch, num_epochs, step, loss.item()))

            step += 1

        if (epoch + 1) % 20 == 0:
            cur_lr = cur_lr / 3
            update_lr(optimizer, cur_lr)

    e_time = time.time()
    train_time = e_time - s_time
    print("Training takes {}s".format(train_time))

    # model.load_state_dict(torch.load("./logs/resnet/params.ckpt"))
    with torch.no_grad():
        correct_num = 0
        total_num = 0
        model.eval()
        for images, labels in test_loader:
            images = images.to(dev)
            labels = labels.to(dev)

            output = model(images)
            _, pred = torch.max(output, dim=1)
            correct_num += (pred.squeeze() == labels.squeeze()).sum().item()
            total_num += labels.size(0)
        print("Accuracy is {} %".format(100 * correct_num / total_num))

    torch.save(model.state_dict(), "./logs/resnet/params.ckpt")
    writer.close()
Example #16
def main():
    args = parse_args()
    print(vars(args))

    cuda = torch.cuda.is_available()
    if cuda:
        print('Device: {}'.format(torch.cuda.get_device_name(0)))

    # Hyperparameters
    epoch = args.epochs
    num_classes = 10
    batch_size = 128
    learning_rate = 1e-4

    # CIFAR-10 class labels
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # Data augmentation
    transform = transforms.Compose([
        # Pad by 4 on each side (32x32 -> 40x40)
        transforms.Pad(4),
        # Random horizontal flip
        transforms.RandomHorizontalFlip(),
        # Random crop back to 32x32
        transforms.RandomCrop(32),
        # Convert to tensor
        transforms.ToTensor(),
        # Normalize
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # CIFAR-10 data path
    cifar10Path = './datasets/cifar-10-batches-py/'

    # Training dataset
    train_dataset = torchvision.datasets.CIFAR10(root=cifar10Path,
                                                 train=True,
                                                 transform=transform,
                                                 download=True)

    # Test dataset (no random augmentation at evaluation time)
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    test_dataset = torchvision.datasets.CIFAR10(root=cifar10Path,
                                                train=False,
                                                transform=test_transform)

    # Training data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    # Test data loader
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    # model part
    model = EmbeddingNet(args.dims)
    try:
        model.load_state_dict(torch.load('model/params_' + str(epoch) +
                                         '.pkl'))
    except FileNotFoundError:
        print("no checkpoint found, initializing...")

    if cuda:
        model = model.cuda()
    print(model)
    #criterion = OnlineTripletLoss(margin=1.0)
    criterion = TripletLoss(margin=1.0)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, 8, gamma=0.1, last_epoch=-1)
    fit(train_loader, test_loader, model, criterion, optimizer, scheduler,
        epoch, cuda)

    epoch += 100
    # for plot train embedding
    train_embeddings, train_targets = extract_embeddings(
        train_loader, model, cuda)
    plot_embeddings(train_loader.dataset,
                    train_embeddings,
                    train_targets,
                    title='train_' + str(epoch))

    # for plot test emebedding
    test_embeddings, test_targets = extract_embeddings(test_loader, model,
                                                       cuda)
    plot_embeddings(test_loader.dataset,
                    test_embeddings,
                    test_targets,
                    title='test_' + str(epoch))

    torch.save(model.state_dict(), 'model/params_' + str(epoch) + '.pkl')