Example #1
def generate_input_transforms(batch: int, height: int, width: int,
                              channels: int, pad: int) -> transforms.Compose:
    """
    Generates a torchvision transformation converting a PIL.Image into a
    tensor usable in a network forward pass.

    Args:
        batch (int): mini-batch size
        height (int): height of input image in pixels
        width (int): width of input image in pixels
        channels (int): color channels of input
        pad (int): Amount of padding on horizontal ends of image

    Returns:
        A torchvision transformation composition converting the input image to
        the appropriate tensor.
    """
    scale = 0  # type: Union[Tuple[int, int], int]
    if height == 1 and width == 0 and channels > 3:
        perm = (1, 0, 2)
        scale = channels
        mode = 'L'
    # arbitrary (or fixed) height and width and channels 1 or 3 => needs a
    # summarizing network (or a not yet implemented scale operation) to move
    # height to the channel dimension.
    elif height > 1 and width == 0 and channels in (1, 3):
        perm = (0, 1, 2)
        scale = height
        mode = 'RGB' if channels == 3 else 'L'
    # fixed height and width image => bicubic scaling of the input image, disable padding
    elif height > 0 and width > 0 and channels in (1, 3):
        perm = (0, 1, 2)
        pad = 0
        scale = (height, width)
        mode = 'RGB' if channels == 3 else 'L'
    elif height == 0 and width == 0 and channels in (1, 3):
        perm = (0, 1, 2)
        pad = 0
        scale = 0
        mode = 'RGB' if channels == 3 else 'L'
    else:
        raise KrakenInputException(
            'Invalid input spec (variable height and fixed width not supported)'
        )

    out_transforms = []
    out_transforms.append(transforms.Lambda(lambda x: x.convert(mode)))
    if scale:
        if isinstance(scale, int):
            if mode not in ['1', 'L']:
                raise KrakenInputException(
                    'Invalid mode {} for line dewarping'.format(mode))
            lnorm = CenterNormalizer(scale)
            out_transforms.append(
                transforms.Lambda(lambda x: dewarp(lnorm, x)))
            out_transforms.append(transforms.Lambda(lambda x: x.convert(mode)))
        elif isinstance(scale, tuple):
            out_transforms.append(transforms.Resize(scale, Image.LANCZOS))
    if pad:
        out_transforms.append(transforms.Pad((pad, 0), fill=255))
    out_transforms.append(transforms.ToTensor())
    # invert
    out_transforms.append(transforms.Lambda(lambda x: x.max() - x))
    out_transforms.append(transforms.Lambda(lambda x: x.permute(*perm)))
    return transforms.Compose(out_transforms)
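A minimal usage sketch for the function above, assuming the kraken-style imports from the example (PIL's Image, CenterNormalizer, dewarp) are in scope; the file name and the 48 px line height are illustrative:

line_transform = generate_input_transforms(batch=1, height=48, width=0,
                                           channels=1, pad=16)
img = Image.open('line.png')               # illustrative path
tensor = line_transform(img).unsqueeze(0)  # add the batch dimension for the forward pass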
Example #2
    def __init__(self, args):

        self.dataset = args['data']['dataset']
        self.batch_size = eval(args['data']['batch_size'])
        tanh = eval(args['data']['tanh_augmentation'])
        self.sigma = eval(args['data']['noise_amplitde'])
        unif = eval(args['data']['dequantize_uniform'])
        label_smoothing = eval(args['data']['label_smoothing'])
        channel_pad = eval(args['data']['pad_noise_channels'])
        channel_pad_sigma = eval(args['data']['pad_noise_std'])

        self.handwriting_type = 'None'
        if args['data'].get('handwriting_type'):
            self.handwriting_type = args['data']['handwriting_type']

        if self.dataset == 'MNIST':
            beta = 0.5
            gamma = 2.
        else:
            beta = torch.Tensor((0.4914, 0.4822, 0.4465)).view(-1, 1, 1)
            gamma = 1. / torch.Tensor((0.247, 0.243, 0.261)).view(-1, 1, 1)

        self.train_augmentor = Augmentor(False, self.sigma, unif, beta, gamma,
                                         tanh, channel_pad, channel_pad_sigma)
        self.test_augmentor = Augmentor(True, 0., unif, beta, gamma, tanh,
                                        channel_pad, channel_pad_sigma)
        self.transform = T.Compose([T.ToTensor(), self.test_augmentor])

        if self.handwriting_type == 'OPERATOR':
            print("Dataset used is operators")
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 12
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            train_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/OPERATOR/handwriting_operators_train_temp.csv'
            test_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/OPERATOR/handwriting_operators_test_temp.csv'

            self.test_data = HandwritingDataset(
                test_csv_path,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)

            self.train_data = HandwritingDataset(
                train_csv_path,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.handwriting_type == 'LETTER':
            print("Dataset used is letters")
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 26
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            train_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/LETTER/handwriting_letters_train.csv'
            test_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/LETTER/handwriting_letters_test.csv'

            self.test_data = HandwritingDataset(
                test_csv_path,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)

            self.train_data = HandwritingDataset(
                train_csv_path,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.handwriting_type == 'EMNIST_LETTER':
            print("Dataset used is emnist letters")
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 26
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            train_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_LETTER/emnist_train.csv'
            test_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_LETTER/emnist_test.csv'

            self.test_data = HandwritingDataset(
                test_csv_path,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)

            self.train_data = HandwritingDataset(
                train_csv_path,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.handwriting_type == 'EMNIST_UPPERCASE_LETTER':
            print("Dataset used is emnist uppercase letters")
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 26
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            train_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_UPPERCASE_LETTER/emnist_uppercase_train_4th_May_2021.csv'
            test_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_UPPERCASE_LETTER/emnist_uppercase_test_3rd_May_2021.csv'

            self.test_data = HandwritingDataset(
                test_csv_path,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)

            self.train_data = HandwritingDataset(
                train_csv_path,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.handwriting_type == 'EMNIST_LOWERCASE_LETTER':
            print("Dataset used is emnist lowercase letters")
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 26
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            train_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_LOWERCASE_LETTER/emnist_lowercase_train_13th_May.csv'
            test_csv_path = '/home/kaushikdas/aashish/pytorch_datasets/EMNIST_LOWERCASE_LETTER/emnist_lowercase_test_13th_May.csv'

            self.test_data = HandwritingDataset(
                test_csv_path,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)

            self.train_data = HandwritingDataset(
                train_csv_path,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.dataset == 'MNIST':
            self.dims = (28, 28)
            if channel_pad:
                raise ValueError(
                    'needs to be fixed, channel padding does not work with mnist'
                )
            self.channels = 1
            self.n_classes = 10
            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)
            data_dir = '/home/kaushikdas/aashish/pytorch_datasets'

            self.test_data = torchvision.datasets.MNIST(
                data_dir,
                train=False,
                download=True,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)
            self.train_data = torchvision.datasets.MNIST(
                data_dir,
                train=True,
                download=True,
                transform=T.Compose([T.ToTensor(), self.train_augmentor]),
                target_transform=self.label_augment)
        elif self.dataset in ['CIFAR10', 'CIFAR100']:
            self.dims = (3 + channel_pad, 32, 32)
            self.channels = 3 + channel_pad

            if self.dataset == 'CIFAR10':
                data_dir = 'cifar_data'
                self.n_classes = 10
                dataset_class = torchvision.datasets.CIFAR10
            else:
                data_dir = 'cifar100_data'
                self.n_classes = 100
                dataset_class = torchvision.datasets.CIFAR100

            self.label_mapping = list(range(self.n_classes))
            self.label_augment = LabelAugmentor(self.label_mapping)

            self.test_data = dataset_class(
                data_dir,
                train=False,
                download=True,
                transform=T.Compose([T.ToTensor(), self.test_augmentor]),
                target_transform=self.label_augment)
            self.train_data = dataset_class(
                data_dir,
                train=True,
                download=True,
                transform=T.Compose([
                    T.RandomHorizontalFlip(),
                    T.ColorJitter(0.1, 0.1, 0.05),
                    T.Pad(8, padding_mode='edge'),
                    T.RandomRotation(12),
                    T.CenterCrop(36),
                    T.RandomCrop(32),
                    T.ToTensor(), self.train_augmentor
                ]),
                target_transform=self.label_augment)

        else:
            raise ValueError(
                f"what is this dataset, {args['data']['dataset']}?")

        self.train_data, self.val_data = torch.utils.data.random_split(
            self.train_data, (len(self.train_data) - 1024, 1024))

        self.val_x = torch.stack([x[0] for x in self.val_data], dim=0).cuda()
        self.val_y = self.onehot(
            torch.LongTensor([x[1] for x in self.val_data]).cuda(),
            label_smoothing)

        self.train_loader = DataLoader(self.train_data,
                                       batch_size=self.batch_size,
                                       shuffle=True,
                                       num_workers=6,
                                       pin_memory=True,
                                       drop_last=True)
        self.test_loader = DataLoader(self.test_data,
                                      batch_size=self.batch_size,
                                      shuffle=False,
                                      num_workers=4,
                                      pin_memory=True,
                                      drop_last=True)
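Because the constructor eval()s the raw config strings, a hypothetical caller has to pass numeric and boolean values as strings. The enclosing class name is not shown in the snippet, so DatasetWrapper below is a stand-in:

args = {'data': {'dataset': 'MNIST',
                 'batch_size': '128',
                 'tanh_augmentation': 'False',
                 'noise_amplitde': '0.05',    # key spelled as in the snippet
                 'dequantize_uniform': 'True',
                 'label_smoothing': '0.0',
                 'pad_noise_channels': '0',
                 'pad_noise_std': '0.0'}}
data = DatasetWrapper(args)                  # stand-in for the enclosing class
for x, y in data.train_loader:
    pass                                     # training step goes here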
Example #3
def main():
    global args
    args = parser.parse_args()
    print(args)

    if not os.path.exists(os.path.join(args.save_root, 'checkpoint_fitnets')):
        os.makedirs(os.path.join(args.save_root, 'checkpoint_fitnets'))

    if args.cuda:
        cudnn.benchmark = True

    print('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])

    tnet = define_tsnet(name=args.t_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['state_dict'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    print('-----------------------------------------------')

    # initialize optimizer
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define loss functions
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
        criterionFitnet = torch.nn.MSELoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()
        criterionFitnet = torch.nn.MSELoss()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()

        adjust_lr(optimizer, epoch)

        # train one epoch
        nets = {'snet': snet, 'tnet': tnet}
        criterions = {
            'criterionCls': criterionCls,
            'criterionFitnet': criterionFitnet
        }
        train(train_loader, nets, optimizer, criterions, epoch)
        epoch_time = time.time() - epoch_start_time
        print('one epoch time is {:02}h{:02}m{:02}s'.format(
            *transform_time(epoch_time)))

        # evaluate on testing set
        print('testing the models......')
        test_start_time = time.time()
        test(test_loader, nets, criterions)
        test_time = time.time() - test_start_time
        print('testing time is {:02}h{:02}m{:02}s'.format(
            *transform_time(test_time)))

        # save model
        print('saving models......')
        save_name = 'fitnet_r{}_r{}_{:>03}.pth.tar'.format(
            args.t_name[6:], args.s_name[6:], epoch)
        save_name = os.path.join(args.save_root, 'checkpoint_fitnets',
                                 save_name)
        if epoch == 1:
            save_checkpoint(
                {
                    'epoch': epoch,
                    'snet': snet.state_dict(),
                    'tnet': tnet.state_dict(),
                }, save_name)
        else:
            save_checkpoint({
                'epoch': epoch,
                'snet': snet.state_dict(),
            }, save_name)
Example #4
        return tensor


class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        return F.normalize(tensor, self.mean, self.std)


############## MNIST

data_transform = transform.Compose([
    transform.Pad(padding=2, fill=0),
    transform.ToTensor(),
    transform.Normalize(mean=[0.5], std=[0.5])
])

inver_transform_MNIST = transform.Compose([
    DeNormalize([0.5], [0.5]),
    lambda x: x.cpu().numpy() * 255.,
])
############## Read from Standard API

MNIST_root = r"./data/mnist"
MNIST_train_set = MNIST(MNIST_root,
                        train=True,
                        transform=data_transform,
                        download=True)
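A round-trip sketch with the pieces above: fetch one padded 32x32 MNIST sample and undo the normalization for display (assumes the DeNormalize class used by inver_transform_MNIST is defined earlier in the file):

img, label = MNIST_train_set[0]        # tensor in [-1, 1], shape (1, 32, 32)
restored = inver_transform_MNIST(img)  # numpy array back in the [0, 255] range
print(restored.shape, label)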
Example #5
    def __init__(self, *args, **kwargs):

        self.image_transform = transforms.Pad(*args, **kwargs)
        self.padding = self.image_transform.padding
Example #6
def compress(quantbits, nz, bitswap, gpu):
    # model and compression params
    zdim = 1 * 16 * 16
    zrange = torch.arange(zdim)
    xdim = 32**2 * 1
    xrange = torch.arange(xdim)
    ansbits = NORM_CONST - 1  # ANS precision
    type = torch.float64  # datatype throughout compression
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    ans_device = device  #"cuda:0"

    # set up the different channel dimension for different latent depths
    if nz == 8:
        reswidth = 61
    elif nz == 4:
        reswidth = 62
    elif nz == 2:
        reswidth = 63
    else:
        reswidth = 64
    assert nz > 0

    print(
        f"{'Bit-Swap' if bitswap else 'BB-ANS'} - MNIST - {nz} latent layers - {quantbits} bits quantization"
    )

    # seed for replicating experiment and stability
    np.random.seed(100)
    random.seed(50)
    torch.manual_seed(50)
    torch.cuda.manual_seed(50)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    # compression experiment params
    experiments = 20
    ndatapoints = 100
    decompress = True

    # <=== MODEL ===>
    model = Model(xs=(1, 32, 32),
                  nz=nz,
                  zchannels=1,
                  nprocessing=4,
                  kernel_size=3,
                  resdepth=8,
                  reswidth=reswidth,
                  tag="batch").to(device)
    model.load_state_dict(
        torch.load(f'model/params/mnist/nz{nz}',
                   map_location=lambda storage, location: storage))
    model.eval()

    print("Discretizing")
    # get discretization bins for latent variables
    zendpoints, zcentres = discretize(nz, quantbits, type, device, model,
                                      "mnist")

    #### priors
    prior_cdfs = logistic_cdf(zendpoints[-1].t(),
                              torch.zeros(1, device=device, dtype=type),
                              torch.ones(1, device=device, dtype=type)).t()
    prior_pmfs = prior_cdfs[:, 1:] - prior_cdfs[:, :-1]
    prior_pmfs = torch.cat((prior_cdfs[:, 0].unsqueeze(1), prior_pmfs,
                            1. - prior_cdfs[:, -1].unsqueeze(1)),
                           dim=1)

    ####

    # get discretization bins for discretized logistic
    xbins = ImageBins(type, device, xdim)
    xendpoints = xbins.endpoints()
    xcentres = xbins.centres()

    print("Load data..")

    # <=== DATA ===>
    class ToInt:
        def __call__(self, pic):
            return pic * 255

    transform_ops = transforms.Compose(
        [transforms.Pad(2), transforms.ToTensor(),
         ToInt()])
    test_set = datasets.MNIST(root="model/data/mnist",
                              train=False,
                              transform=transform_ops,
                              download=True)

    # sample (experiments, ndatapoints) indices from the test set without replacement
    print(len(test_set.data))
    if not os.path.exists("bitstreams/mnist/indices.npy"):
        randindices = np.random.choice(len(test_set.data),
                                       size=(experiments, ndatapoints),
                                       replace=False)
        np.save("bitstreams/mnist/indices", randindices)  # np.save appends .npy
    else:
        randindices = np.load("bitstreams/mnist/indices.npy")

    print("Setting up metrics..")
    # metrics for the results
    nets = np.zeros((experiments, ndatapoints), dtype=float)  # np.float was removed in NumPy 1.24
    elbos = np.zeros((experiments, ndatapoints), dtype=float)
    cma = np.zeros((experiments, ndatapoints), dtype=float)
    total = np.zeros((experiments, ndatapoints), dtype=float)

    print("Compression..")
    for ei in range(experiments):
        experiment_start_time = time.time()
        print(f"Experiment {ei + 1}")
        subset = Subset(test_set, randindices[ei])
        test_loader = DataLoader(dataset=subset,
                                 batch_size=1,
                                 shuffle=False,
                                 drop_last=True)
        datapoints = list(test_loader)

        # < ===== COMPRESSION ===>
        # initialize compression
        model.compress()
        state = list(
            map(
                int,
                np.random.randint(
                    low=1 << 16,
                    high=(1 << NORM_CONST) - 1,
                    size=(200),
                    dtype=np.uint32)))  # fill state list with 'random' bits
        state[-1] = state[-1] << 16  #NORM_CONST

        states = [state.copy() for _ in range(len(datapoints))]

        initialstates = deepcopy(states)
        reststates = None

        state_init = time.time()

        iterator = tqdm(range(len(datapoints)), desc="Sender")

        # <===== SENDER =====>

        ####
        xs = []
        for xi in range(len(datapoints)):
            (x, _) = datapoints[xi]
            x = x.to(device).view(xdim)
            xs.append(x)

        for zi in range(nz):
            mus = []
            scales = []
            for xi in tqdm(range(len(datapoints))):
                input = zcentres[zi - 1, zrange,
                                 zsyms[xi]] if zi > 0 else xcentres[
                                     xrange, xs[xi].long()]
                mu, scale = model.infer(zi)(given=input)
                mus.append(mu)
                scales.append(scale)

            s = time.time()
            cdfs_b = logistic_cdf(
                torch.stack([zendpoints[zi]] * len(datapoints)).permute(
                    2, 0, 1), torch.stack(mus),
                torch.stack(scales)).permute(1, 2, 0)

            pmfs_b = torch.cat(
                (cdfs_b[:, :, 0].unsqueeze(2), cdfs_b[:, :, 1:] -
                 cdfs_b[:, :, :-1], 1. - cdfs_b[:, :, -1].unsqueeze(2)),
                dim=2)

            ans = ANS(pmfs_b.to(ans_device), bits=ansbits, quantbits=quantbits)
            t1 = time.time()
            states, zsymtops = ans.batch_decode(states)
            t2 = time.time()
            zsymtops = zsymtops.to(device)

            if zi == 0:
                reststates = states.copy()
                assert all(
                    [len(rb) > 1 for rb in reststates]
                ), "too few initial bits"  # otherwise initial state consists of too few bits

            z_dec_pmfs = []
            mus = []
            scales = []
            for zsymtop in tqdm(zsymtops):
                z = zcentres[zi, zrange, zsymtop]
                mu, scale = model.generate(zi)(given=z)
                mus.append(mu)
                scales.append(scale)

            cdfs_b = logistic_cdf(
                torch.stack([(zendpoints[zi - 1] if zi > 0 else xendpoints)] *
                            len(datapoints)).permute(2, 0, 1),
                torch.stack(mus), torch.stack(scales)).permute(1, 2, 0)

            pmfs_b = torch.cat(
                (cdfs_b[:, :, 0].unsqueeze(2), cdfs_b[:, :, 1:] -
                 cdfs_b[:, :, :-1], 1. - cdfs_b[:, :, -1].unsqueeze(2)),
                dim=2)

            ans = ANS(pmfs_b.to(ans_device), bits=ansbits, quantbits=quantbits)

            to_encode = zsyms if zi > 0 else torch.stack(xs).long()
            states = ans.batch_encode(states, to_encode)

            zsyms = zsymtops

        states = ANS(torch.stack([prior_pmfs for _ in range(len(datapoints))
                                  ]).to(ans_device),
                     bits=ansbits,
                     quantbits=quantbits).batch_encode(states, zsymtops)

        totaladdedbits_for_xs = [
            (len(state) - len(initialstate)) * 32
            for (state, initialstate) in zip(states, initialstates)
        ]

        totalbits_for_xs = [(len(state) - (len(restbits) - 1)) * 32
                            for (state, restbits) in zip(states, reststates)]

        iterator = tqdm(enumerate(zip(totaladdedbits_for_xs,
                                      totalbits_for_xs)))
        with torch.no_grad():
            for xi, (totaladdedbits, totalbits) in iterator:
                x = xs[xi]
                model.compress(False)
                logrecon, logdec, logenc, _ = model.loss(
                    x.view((-1, ) + model.xs))
                elbo = -logrecon + torch.sum(-logdec + logenc)
                model.compress(True)

                nets[ei, xi] = (totaladdedbits / xdim) - nets[ei, :xi].sum()
                elbos[ei, xi] = elbo.item() / xdim
                cma[ei, xi] = totalbits / (xdim * (xi + 1))
                total[ei, xi] = totalbits

                iterator.set_postfix_str(
                    s=
                    f"N:{nets[ei,:xi+1].mean():.2f}±{nets[ei,:xi+1].std():.2f}, D:{nets[ei,:xi+1].mean()-elbos[ei,:xi+1].mean():.4f}, C: {cma[ei,:xi+1].mean():.2f}, T: {totalbits:.0f}",
                    refresh=False)

        state_file = f"bitstreams/mnist/nz{nz}/{'Bit-Swap' if bitswap else 'BB-ANS'}/{'Bit-Swap' if bitswap else 'BB-ANS'}_{quantbits}bits_nz{nz}_experiment{ei + 1}_batch"
        print(state_file)
        # write state to file
        # print(len(states))
        # print([len(s) for s in states])

        max_common_len = min([len(s) for s in states])
        common_len = 0

        for pref in range(max_common_len):
            if len(set(s[pref] for s in states)) > 1:
                break
            common_len = pref + 1

        print("common len:", common_len)
        states_to_dump = (states[0][:common_len],
                          [s[common_len:] for s in states])
        with open(state_file, "wb") as fp:
            pickle.dump(states_to_dump, fp)

        state = None
        # open state file
        with open(state_file, "rb") as fp:
            states_prefix, states_postfixes = pickle.load(fp)
            states = [states_prefix + sp for sp in states_postfixes]

        print([len(s) for s in states])
        print(
            sum([
                len(s) - len(inits)
                for (s, inits) in zip(states, initialstates)
            ]))

        # <===== RECEIVER =====>

        # priors
        states, zsymtops = ANS(torch.stack(
            [prior_pmfs for _ in range(len(datapoints))]).to(ans_device),
                               bits=ansbits,
                               quantbits=quantbits).batch_decode(states)
        zsymtops = zsymtops.to(device)

        for zi in reversed(range(nz)):
            zs = z = zcentres[zi, zrange, zsymtops]

            z_dec_pmfs = []
            mus = []
            scales = []
            for xi in tqdm(range(len(datapoints))):

                z = zs[xi]
                mu, scale = model.generate(zi)(given=z)
                mus.append(mu)
                scales.append(scale)

            cdfs_b = logistic_cdf(
                torch.stack([(zendpoints[zi - 1] if zi > 0 else xendpoints)] *
                            len(datapoints)).permute(2, 0, 1),
                torch.stack(mus), torch.stack(scales)).permute(1, 2, 0)

            pmfs_b = torch.cat(
                (cdfs_b[:, :, 0].unsqueeze(2), cdfs_b[:, :, 1:] -
                 cdfs_b[:, :, :-1], 1. - cdfs_b[:, :, -1].unsqueeze(2)),
                dim=2)

            ans = ANS(pmfs_b.to(ans_device), bits=ansbits, quantbits=quantbits)

            states, symbols = ans.batch_decode(states)
            symbols = symbols.to(device)

            inputs = zcentres[zi - 1, zrange,
                              symbols] if zi > 0 else xcentres[xrange, symbols]

            mus = []
            scales = []

            for input in tqdm(inputs):
                mu, scale = model.infer(zi)(given=input)
                mus.append(mu)
                scales.append(scale)

            cdfs_b = logistic_cdf(
                torch.stack([zendpoints[zi]] * len(datapoints)).permute(
                    2, 0, 1), torch.stack(mus),
                torch.stack(scales)).permute(1, 2, 0)

            pmfs_b = torch.cat(
                (cdfs_b[:, :, 0].unsqueeze(2), cdfs_b[:, :, 1:] -
                 cdfs_b[:, :, :-1], 1. - cdfs_b[:, :, -1].unsqueeze(2)),
                dim=2)

            ans = ANS(pmfs_b.to(ans_device), bits=ansbits, quantbits=quantbits)

            states = ans.batch_encode(states, zsymtops)
            zsymtops = symbols

        assert all([
            torch.all(datapoints[xi][0].view(xdim).long().to(device) ==
                      zsymtops[xi].to(device)) for xi in range(len(datapoints))
        ])

        assert initialstates == states
        experiment_end_time = time.time()
        print("Experiment time", experiment_end_time - experiment_start_time)

    print(
        f"N:{nets.mean():.4f}±{nets.std():.2f}, E:{elbos.mean():.4f}±{elbos.std():.2f}, D:{nets.mean() - elbos.mean():.6f}"
    )

    # save experiments
    np.save(
        f"plots/mnist{nz}/{'bitswap' if bitswap else 'bbans'}_{quantbits}bits_nets",
        nets)
    np.save(
        f"plots/mnist{nz}/{'bitswap' if bitswap else 'bbans'}_{quantbits}bits_elbos",
        elbos)
    np.save(
        f"plots/mnist{nz}/{'bitswap' if bitswap else 'bbans'}_{quantbits}bits_cmas",
        cma)
    np.save(
        f"plots/mnist{nz}/{'bitswap' if bitswap else 'bbans'}_{quantbits}bits_total",
        total)
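A hedged invocation sketch for compress; it assumes the original repository layout (trained parameters under model/params/mnist and the bitstreams/ and plots/ output directories) is in place:

# Bit-Swap compression of MNIST with 8 latent layers and 10-bit quantization
compress(quantbits=10, nz=8, bitswap=True, gpu=0)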
Example #7
    for i, (video_category,
            video_name) in enumerate(zip(video_categorys, video_names)):
        train_video = readShortVideo(video_path, video_category, video_name)
        test_videos.append(train_video)
        pbar.update(1)
if mode != "test":
    print("\nloading labels...")
    for i, action_label in enumerate(action_labels):
        test_labels.append(int(action_label))
print("test_videos_len:", len(test_videos))

# extracting features
cnn_feature_extractor = Resnet50().cuda()  # to 2048 dims
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad((0, 40), fill=0, padding_mode='constant'),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
cnn_feature_extractor.eval()
train_features = []
with torch.no_grad():
    print("\nextracting videos feature...")
    with tqdm(total=total_num) as pbar:
        for train_video in test_videos:
            local_batch = []
            for frame in train_video:
                frame = transform(frame)
                local_batch.append(frame)
            local_batch = torch.stack(local_batch)
Example #8
def main():
    torch.manual_seed(0)
    device = torch.device('cuda')
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Pad(4),
                             transforms.RandomCrop(32),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=200, shuffle=False)

    model = VGG(depth=19)
    model.to(device)

    # Train the base VGG-19 model
    print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
    epochs = 160
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    for epoch in range(epochs):
        if epoch in [epochs * 0.5, epochs * 0.75]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1
        train(model, device, train_loader, optimizer, True)
        test(model, device, test_loader)
    torch.save(model.state_dict(), 'vgg19_cifar10.pth')

    # Test base model accuracy
    print('=' * 10 + 'Test the original model' + '=' * 10)
    model.load_state_dict(torch.load('vgg19_cifar10.pth'))
    test(model, device, test_loader)
    # top1 = 93.60%

    # Pruning Configuration, in paper 'Learning efficient convolutional networks through network slimming',
    configure_list = [{
        'sparsity': 0.7,
        'op_types': ['BatchNorm2d'],
    }]

    # Prune model and test accuracy without fine tuning.
    print('=' * 10 + 'Test the pruned model before fine tune' + '=' * 10)
    pruner = SlimPruner(model, configure_list)
    model = pruner.compress()
    test(model, device, test_loader)
    # top1 = 93.55%

    # Fine tune the pruned model for 40 epochs and test accuracy
    print('=' * 10 + 'Fine tuning' + '=' * 10)
    optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
    best_top1 = 0
    for epoch in range(40):
        pruner.update_epoch(epoch)
        print('# Epoch {} #'.format(epoch))
        train(model, device, train_loader, optimizer_finetune)
        top1 = test(model, device, test_loader)
        if top1 > best_top1:
            best_top1 = top1
            # Export the best model, 'model_path' stores state_dict of the pruned model,
            # mask_path stores mask_dict of the pruned model
            pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth')

    # Test the exported model
    print('=' * 10 + 'Test the export pruned model after fine tune' + '=' * 10)
    new_model = VGG(depth=19)
    new_model.to(device)
    new_model.load_state_dict(torch.load('pruned_vgg19_cifar10.pth'))
    test(new_model, device, test_loader)
Example #9
def main_worker(gpu, args):
    args.gpu = gpu

    if args.gpu is not None:
        logger.info(f"Use GPU: {args.gpu} for testing.")

    model = configure(args)

    if not torch.cuda.is_available():
        logger.warning("Using CPU, this will be slow.")
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    cudnn.benchmark = True

    # Set eval mode.
    model.eval()

    # Get video filename.
    filename = os.path.basename(args.file)

    # Image preprocessing operation
    tensor2pil = transforms.ToPILImage()

    video_capture = cv2.VideoCapture(args.file)
    # Prepare to write the processed image into the video.
    fps = video_capture.get(cv2.CAP_PROP_FPS)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    # Set video size
    size = (int(video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    sr_size = (size[0] * args.upscale_factor, size[1] * args.upscale_factor)
    pare_size = (sr_size[0] * 2 + 10, sr_size[1] + 10 + sr_size[0] // 5 - 9)
    # Video write loader.
    sr_writer = cv2.VideoWriter(
        os.path.join("videos", f"sr_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, sr_size)
    compare_writer = cv2.VideoWriter(
        os.path.join("videos", f"compare_{args.upscale_factor}x_{filename}"),
        cv2.VideoWriter_fourcc(*"MPEG"), fps, pare_size)

    # read frame.
    with torch.no_grad():
        success, raw_frame = video_capture.read()
        progress_bar = tqdm(
            range(total_frames),
            desc="[processing video and saving/view result videos]")
        for _ in progress_bar:
            if success:
                # Read image to tensor and transfer to the specified device for processing.
                lr = process_image(raw_frame, args.gpu)

                sr = model(lr)

                sr = sr.cpu()
                sr = sr.data[0].numpy()
                sr *= 255.0
                sr = (np.uint8(sr)).transpose((1, 2, 0))
                # save sr video
                sr_writer.write(sr)

                # build the comparison frame: crop five regions (corners and center) for the bottom strip
                sr = tensor2pil(sr)
                # Five areas are selected as the bottom contrast map.
                crop_sr_images = transforms.FiveCrop(size=sr.width // 5 - 9)(sr)
                crop_sr_images = [
                    np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(image))
                    for image in crop_sr_images
                ]
                sr = transforms.Pad(padding=(5, 0, 0, 5))(sr)
                # Five areas in the contrast map are selected as the bottom contrast map
                compare_image = transforms.Resize(
                    (sr_size[1], sr_size[0]),
                    interpolation=Mode.BICUBIC)(tensor2pil(raw_frame))
                crop_compare_images = transforms.FiveCrop(
                    size=compare_image.width // 5 - 9)(compare_image)
                crop_compare_images = [
                    np.asarray(transforms.Pad((0, 5, 10, 0))(image))
                    for image in crop_compare_images
                ]
                compare_image = transforms.Pad(padding=(0, 0, 5, 5))(compare_image)
                # concatenate all the pictures to one single picture
                # 1. Mosaic the left and right images of the video.
                top_image = np.concatenate(
                    (np.asarray(compare_image), np.asarray(sr)), axis=1)
                # 2. Mosaic the bottom left and bottom right images of the video.
                bottom_image = np.concatenate(crop_compare_images +
                                              crop_sr_images,
                                              axis=1)
                bottom_image_height = int(top_image.shape[1] /
                                          bottom_image.shape[1] *
                                          bottom_image.shape[0])
                bottom_image_width = top_image.shape[1]
                # 3. Adjust to the right size.
                bottom_image = np.asarray(
                    transforms.Resize(
                        (bottom_image_height,
                         bottom_image_width))(tensor2pil(bottom_image)))
                # 4. Combine the bottom zone with the upper zone.
                final_image = np.concatenate((top_image, bottom_image))

                # save compare video
                compare_writer.write(final_image)

                if args.view:
                    # display video
                    cv2.imshow("LR video convert SR video ", final_image)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break

                # next frame
                success, raw_frame = video_capture.read()
Example #10
def main():
    transform_mnist = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ])
    transform_svhn = transforms.Compose([transforms.ToTensor()])
    mnist = tv.datasets.MNIST('./dataset/mnist',
                              download=False,
                              transform=transform_mnist,
                              train=True)
    mnist_val = tv.datasets.MNIST('./dataset/mnist',
                                  download=False,
                                  transform=transform_mnist,
                                  train=False)
    svhn = tv.datasets.SVHN('./dataset/svhn',
                            download=False,
                            transform=transform_svhn,
                            split='train')
    svhn_val = tv.datasets.SVHN('./dataset/svhn',
                                download=False,
                                transform=transform_svhn,
                                split='test')
    device = 'cuda'
    Fe = F_extractor().to(device)
    F1 = F_label().to(device)
    F2 = F_label().to(device)
    Ft = F_label().to(device)

    models = {'Fe': Fe, 'F1': F1, 'F2': F2, 'Ft': Ft}
    # print(mnist.__add__(torch.Tensor))
    print(svhn.data.shape)
    # x=torch.utils.data.TensorDataset(mnist.train_data, mnist.train_labels)
    # print(mnist.train_labels)
    batch_size = 128
    train_data = torch.utils.data.DataLoader(svhn,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=DATA_WORKERS)

    val_data = torch.utils.data.DataLoader(svhn_val,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=DATA_WORKERS)
    train_1(device,
            models,
            train_data,
            val_data,
            epochs=2,
            restore=False,
            path='phase1.pth.tar')
    train_2(device,
            models,
            svhn,
            mnist,
            mnist_val,
            step=200,
            val_step=400,
            iters=2,
            restore=False,
            path='phase2.pth.tar')
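The transform_mnist pipeline pads the 28x28 digits to 32x32 and repeats the single channel three times, so MNIST batches match SVHN's 3x32x32 input shape; a quick shape check (sketch, to run inside main() after the datasets are built):

x, _ = mnist[0]
y, _ = svhn[0]
assert x.shape == (3, 32, 32) and y.shape == x.shape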
Example #11
        delta_sample_transformer_params(encoder.transformer.transformers,
                                        encoder_out["transform_params"])

        z = pyro.sample(
            "z",
            D.Normal(encoder_out["z_mu"],
                     torch.exp(encoder_out["z_std"]) + 1e-3).to_event(1),
        )


if __name__ == "__main__":
    # pyro.enable_validation(True)
    eval_every = 2

    augmentation = tvt.Compose([
        tvt.Pad(6),
        tvt.RandomAffine(degrees=90.0,
                         translate=(0.14, 0.14),
                         scale=(0.8, 1.2)),
        tvt.ToTensor(),
    ])

    mnist = MNIST(
        "./data",
        download=True,
        transform=augmentation,
        # target_transform=target_transform
    )
    mnist_test = MNIST(
        "./data",
        download=True,
        train=False,               # assumed completion; the original snippet breaks off here
        transform=augmentation,
    )
Example #12
def main(args):
    save_folder = args.affix
    
    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)

    makedirs(log_folder)
    makedirs(model_folder)

    setattr(args, 'log_folder', log_folder)
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, 'train', 'info')
    print_args(args, logger)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic=True 
        torch.backends.cudnn.benchmark=False
    
    if args.model == "VGG16":
        net = vgg(dataset=args.dataset, depth=16)
        if args.mask:
            net = masked_vgg(dataset=args.dataset, depth=16)
    elif args.model == "WideResNet":
        net = WideResNet(depth=28, num_classes=10 if args.dataset == 'cifar10' else 100, widen_factor=8)
        if args.mask:
            net = MaskedWideResNet(depth=28, num_classes=10 if args.dataset == 'cifar10' else 100, widen_factor=8)
   
    net.to(device)
    
    trainer = Trainer(args, logger)
    
    loss = nn.CrossEntropyLoss()
 
    
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
    if args.dataset == 'cifar10':
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=100, shuffle=True, **kwargs)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.Pad(4),
                            transforms.RandomCrop(32),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR100('./data.cifar100', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                        ])),
            batch_size=100, shuffle=True, **kwargs)
        
    
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)
    trainer.train(net, loss, device, train_loader, test_loader, optimizer=optimizer, scheduler=scheduler)
Example #13
    def __init__(self,
                 data_dir,
                 trial,
                 transform=None,
                 colorIndex=None,
                 thermalIndex=None):
        # Load training images (path) and labels
        data_dir = '../Datasets/RegDB/'
        train_color_list = data_dir + 'idx/train_visible_{}'.format(
            trial) + '.txt'
        train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(
            trial) + '.txt'

        color_img_file, train_color_label = load_data(train_color_list)
        thermal_img_file, train_thermal_label = load_data(train_thermal_list)

        train_color_image = []
        for i in range(len(color_img_file)):

            img = Image.open(data_dir + color_img_file[i])
            img = img.resize((144, 288), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
            pix_array = np.array(img)
            train_color_image.append(pix_array)
        train_color_image = np.array(train_color_image)

        train_thermal_image = []
        for i in range(len(thermal_img_file)):
            img = Image.open(data_dir + thermal_img_file[i])
            img = img.resize((144, 288), Image.LANCZOS)
            pix_array = np.array(img)
            train_thermal_image.append(pix_array)
        train_thermal_image = np.array(train_thermal_image)

        # images were loaded via PIL, so they are already RGB
        self.train_color_image = train_color_image
        self.train_color_label = train_color_label

        # likewise for the thermal images
        self.train_thermal_image = train_thermal_image
        self.train_thermal_label = train_thermal_label

        self.transform = transform
        self.cIndex = colorIndex
        self.tIndex = thermalIndex

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.transform_thermal = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(10),
            transforms.RandomCrop((288, 144)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize,
            ChannelRandomErasing(probability=0.5),
            ChannelAdapGray(probability=0.5)
        ])

        self.transform_color = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(10),
            transforms.RandomCrop((288, 144)),
            transforms.RandomHorizontalFlip(),
            # transforms.RandomGrayscale(p = 0.1),
            transforms.ToTensor(),
            normalize,
            ChannelRandomErasing(probability=0.5)
        ])

        self.transform_color1 = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(10),
            transforms.RandomCrop((288, 144)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize,
            ChannelRandomErasing(probability=0.5),
            ChannelExchange(gray=2)
        ])
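The three pipelines above are normally applied per sample in __getitem__, which the snippet omits; a hedged sketch, with dataset standing for an instance of this class:

img = dataset.train_color_image[0]    # H x W x 3 uint8 array
aug = dataset.transform_color(img)    # random crop/flip, normalize, erasing
aug1 = dataset.transform_color1(img)  # same, plus ChannelExchange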
Example #14
    def __init__(self, json_path, img_path, set_type="train"):
        """
        Args:
            json_path (string): path to json file
            img_path (string): path to the folder where images are
            set_type: train or test , val is not supported yet
        """
        self.set_type = set_type
        self.to_tensor = transforms.ToTensor()
        self.data_info = json.loads(open(json_path).read())
        self.label_count = {x: 0 for x in range(229)}

        # read jsons
        arrr = []  # labels in one hot encoding,
        imgs = []
        names = []
        i = 0
        if set_type != "test":
            for ann in self.data_info['annotations']:
                #                if i < 10000:
                #                    continue
                imgs.append(os.path.join(img_path, ann['imageId'] + '.jpg'))
                names.append(ann['imageId'])
                label_arr = []
                label_arr = np.sum([
                    np.eye(n_classes, dtype="uint8")[int(x)]
                    for x in ann['labelId']
                ],
                                   axis=0,
                                   dtype="uint8").tolist()
                arrr.append(deepcopy(label_arr))
                for x in ann['labelId']:
                    self.label_count[int(x)] += 1
                i += 1
                if i % 100000 == 0:
                    print("Processed: " + str(i))
#                    break
        else:
            for ann in self.data_info['images']:
                imgs.append(
                    os.path.join(img_path,
                                 str(ann['imageId']) + '.jpg'))
                names.append(ann['imageId'])

        self.names = names
        self.image_arr = imgs
        # Second column is the labels

        #         self.label_arr = torch.stack(arrr)
        if set_type != "test":
            self.label_arr = arrr

        # calculate class weight
#        print(self.label_count)
        if set_type == "train":
            self.class_weight = self._get_class_weight()
            print("Class weights:{0}".format(self.class_weight))
        # Calculate len
        self.data_len = len(self.image_arr)
        self.transformations = transforms.Compose([
            transforms.Pad(8),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
Example #15
def train(fine_tuning=False):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # denseNet = models.densenet161(pretrained=True)
    # for param in denseNet.parameters():
    #     param.requires_grad = False
    # net = models.resnet50(num_classes=2)

    # net = models.inception_v3(num_classes=2)
    # net.aux_logits = False

    # net = models.resnet101(pretrained=True)

    # net = MobileNetV2(n_class=3)
    net = resnet50(num_classes=3)
    index = 0
    if fine_tuning:
        index = INDEX
        # load previously saved weights for fine-tuning
        state_dict = torch.load('D:/models/image-classify/' + net_name + '-' +
                                str(INDEX) + '.pth')
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # name = k[21:]  # remove `module.`
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        # load params
        # model.load_state_dict(new_state_dict, strict=False)
        net.load_state_dict(new_state_dict)
        # for i, para in enumerate(net.parameters()):
        #     if i < 280:
        #         para.requires_grad = False
        #     else:
        #         para.requires_grad = True
        #     print(i)
    print(net)
    net = torch.nn.DataParallel(net)
    net.to(device)
    # loss function
    loss_fun = torch.nn.CrossEntropyLoss()
    # optimizer
    optimizer = torch.optim.Adam(net.parameters())
    # training data loading
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(90),
        transforms.RandomAffine(15),
        transforms.ColorJitter(brightness=0.3,
                               contrast=0.3,
                               saturation=0.3,
                               hue=0.3),
        transforms.RandomGrayscale(),
        transforms.Resize(256),
        transforms.Pad(3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    dataset_train = MyDataset(root="./datalist/train.txt",
                              transform=transform_train,
                              loader=resample_loader)
    dataLoader_train = torch.utils.data.DataLoader(dataset=dataset_train,
                                                   batch_size=BATCH_SIZE,
                                                   shuffle=True,
                                                   num_workers=NUMBER_WORK)
    step_out = 500
    epochs = 500
    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        acc_train_sum = 0
        loss_sum_step = 0
        acc_train_sum_step = 0
        for step_train, data in enumerate(dataLoader_train):
            b_x, b_y = data
            b_x, b_y = b_x.to(device), b_y.to(device)
            outs = net(b_x)
            # torch.max returns the maximum values and their indices
            _, predicted = torch.max(outs.data, 1)
            acc_train = float(predicted.eq(b_y.data).cpu().sum()) / float(
                b_y.size(0))
            acc_train_sum += acc_train
            acc_train_sum_step += acc_train
            loss = loss_fun(outs, b_y)
            loss_sum_step += loss.cpu().data.numpy()
            loss_sum += loss.cpu().data.numpy()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (step_train + 1) % step_out == 0:
                print('epoch: %d, step: %d, loss: %.6f, accuracy: %.6f' %
                      ((epoch + index + 1), (step_train + 1),
                       (loss_sum_step / step_out),
                       (acc_train_sum_step / step_out)))
                loss_sum_step = 0
                acc_train_sum_step = 0
        torch.save(
            net.state_dict(), 'D:/models/image-classify/' + net_name + '-' +
            str(index + epoch + 1) + '.pth')
        logger.info('epoch: %d, train_loss: %.6f, train_acc: %.6f' %
                    ((epoch + index + 1), (loss_sum / float(step_train + 1)),
                     (acc_train_sum / float(step_train + 1))))
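
Because the checkpoint above is saved from a DataParallel-wrapped model, every key carries a 'module.' prefix, which is why the k[7:] slice is needed on reload. Saving the unwrapped weights sidesteps that entirely; a small sketch using the same path scheme:

        torch.save(
            net.module.state_dict(),  # keys without the 'module.' prefix
            'D:/models/image-classify/' + net_name + '-' + str(index + epoch + 1) + '.pth')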
Beispiel #16
0
def plot_ranked(reverse=False, N=5):

    imsize = 150
    
    crop      = transforms.CenterCrop(imsize)
    pad       = transforms.Pad((0, 0, 1, 1), fill=0)
    totensor  = transforms.ToTensor()
    normalise = transforms.Normalize((0.0031,), (0.0350,))

    transform = transforms.Compose([
        crop,
        pad,
        totensor,
        normalise,
    ])
    
    test_data = MBFRConfident('mirabest', train=False, transform=transform)
    
    path1 = 'lenet_overlap.csv'
    path2 = 'dn16_overlap.csv'
    
    df1 = pd.read_csv(path1)
    df2 = pd.read_csv(path2)

    targ1 = df1['target'].values
    p1    = df1['softmax prob'].values
    olap1 = df1['average overlap'].values
    slap1 = df1['overlap variance'].values

    targ2 = df2['target'].values
    p2    = df2['softmax prob'].values
    olap2 = df2['average overlap'].values
    slap2 = df2['overlap variance'].values
    
    p1 = [np.array(p1[i].lstrip('[').rstrip(']').split(), dtype=float) for i in range(len(targ1))]
    p2 = [np.array(p2[i].lstrip('[').rstrip(']').split(), dtype=float) for i in range(len(targ2))]
    
    p1 = np.array(p1)
    p2 = np.array(p2)
    
    diff = slap1 - slap2
    
    if reverse:
        i_ep = np.argsort(diff)[::-1]
    else:
        i_ep = np.argsort(diff)
    print(diff[i_ep[0]])
    
    from mpl_toolkits.axes_grid1 import ImageGrid

    fig = pl.figure()
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(1, N),  # creates a 1xN grid of axes
                     axes_pad=0.1,  # pad between axes in inches
                     )
    j=0
    for i in i_ep[0:N]:
        subset_indices = [i] # select your indices here as a list
        subset = torch.utils.data.Subset(test_data, subset_indices)
        testloader_ordered = torch.utils.data.DataLoader(subset, batch_size=1, shuffle=False)
        data, target = next(iter(testloader_ordered))
            
        grid[j].imshow(np.squeeze(data))  # the AxesGrid object works as a list of axes
        circle = pl.Circle((75, 75), 25, fill=False, ec="white")
        grid[j].add_patch(circle)
        grid[j].text(5, 15, "Source: {:2d}".format(i), {'color': 'white', 'fontsize': 10})
        grid[j].text(5, 145, "True: {}, Predicted: [{},{}]".format(targ1[i], np.argmax(p1[i, :]), np.argmax(p2[i, :])), {'color': 'white', 'fontsize': 10})
        grid[j].axis('off')
        j += 1

    #grid[j].axis('off'); grid[j+1].axis('off'); grid[j+2].axis('off')

    pl.show()

    return
Beispiel #17
0
def get_data(args):
    if args.dataset == "svhn":
        transform_train = tr.Compose(
            [tr.Pad(4, padding_mode="reflect"),
             tr.RandomCrop(im_sz),
             tr.ToTensor(),
             tr.Normalize((.5, .5, .5), (.5, .5, .5)),
             lambda x: x + args.sigma * t.randn_like(x)]
        )
    else:
        transform_train = tr.Compose(
            [tr.Pad(4, padding_mode="reflect"),
             tr.RandomCrop(im_sz),
             tr.RandomHorizontalFlip(),
             tr.ToTensor(),
             tr.Normalize((.5, .5, .5), (.5, .5, .5)),
             lambda x: x + args.sigma * t.randn_like(x)]
        )
    transform_test = tr.Compose(
        [tr.ToTensor(),
         tr.Normalize((.5, .5, .5), (.5, .5, .5)),
         lambda x: x + args.sigma * t.randn_like(x)]
    )
    def dataset_fn(train, transform):
        if args.dataset == "cifar10":
            return tv.datasets.CIFAR10(root=args.data_root, transform=transform, download=True, train=train)
        elif args.dataset == "cifar100":
            return tv.datasets.CIFAR100(root=args.data_root, transform=transform, download=True, train=train)
        else:
            return tv.datasets.SVHN(root=args.data_root, transform=transform, download=True,
                                    split="train" if train else "test")

    # get all training inds
    full_train = dataset_fn(True, transform_train)
    all_inds = list(range(len(full_train)))
    # set seed
    np.random.seed(1234)
    # shuffle
    np.random.shuffle(all_inds)
    # separate out validation set
    if args.n_valid is not None:
        valid_inds, train_inds = all_inds[:args.n_valid], all_inds[args.n_valid:]
    else:
        valid_inds, train_inds = [], all_inds
    train_inds = np.array(train_inds)
    train_labeled_inds = []
    other_inds = []
    train_labels = np.array([full_train[ind][1] for ind in train_inds])
    if args.labels_per_class > 0:
        for i in range(args.n_classes):
            print(i)
            train_labeled_inds.extend(train_inds[train_labels == i][:args.labels_per_class])
            other_inds.extend(train_inds[train_labels == i][args.labels_per_class:])
    else:
        train_labeled_inds = train_inds

    dset_train = DataSubset(
        dataset_fn(True, transform_train),
        inds=train_inds)
    dset_train_labeled = DataSubset(
        dataset_fn(True, transform_train),
        inds=train_labeled_inds)
    dset_valid = DataSubset(
        dataset_fn(True, transform_test),
        inds=valid_inds)
    dload_train = DataLoader(dset_train, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    dload_train_labeled = DataLoader(dset_train_labeled, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    dload_train_labeled = cycle(dload_train_labeled)
    dset_test = dataset_fn(False, transform_test)
    dload_valid = DataLoader(dset_valid, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
    dload_test = DataLoader(dset_test, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
    return dload_train, dload_train_labeled, dload_valid, dload_test
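
cycle above is referenced but not defined in this excerpt; a common minimal implementation that re-iterates a DataLoader forever (a sketch, not necessarily the author's version):

def cycle(loader):
    # yield batches endlessly, restarting the loader whenever it is
    # exhausted so each pass reshuffles when shuffle=True
    while True:
        for batch in loader:
            yield batch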
Beispiel #18
0
def make_dataloader(cfg):
    if cfg.DATASETS.HARD_AUG:
        train_transforms = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TRAIN),
            T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
            T.Pad(cfg.INPUT.PADDING),
            T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
            #T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.2),
            T.RandomAffine(0,
                           translate=None,
                           scale=[0.9, 1.1],
                           shear=None,
                           resample=False,
                           fillcolor=128),
            T.ToTensor(),
            T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
            RandomErasing(probability=cfg.INPUT.RE_PROB,
                          mean=cfg.INPUT.PIXEL_MEAN)
        ])
    else:
        train_transforms = T.Compose([
            T.Resize(cfg.INPUT.SIZE_TRAIN),
            T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
            T.Pad(cfg.INPUT.PADDING),
            T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
            T.ToTensor(),
            T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
            RandomErasing(probability=cfg.INPUT.RE_PROB,
                          mean=cfg.INPUT.PIXEL_MEAN)
        ])
    val_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    val_transforms_center = T.Compose([
        T.Resize([x + 10 for x in cfg.INPUT.SIZE_TEST]),
        center_crop(256, 128),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])
    val_transforms_lt = T.Compose([
        T.Resize([x + 10 for x in cfg.INPUT.SIZE_TEST]),
        crop_lt(256, 128),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    val_transforms_rt = T.Compose([
        T.Resize([x + 10 for x in cfg.INPUT.SIZE_TEST]),
        crop_rt(256, 128),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    val_transforms_lb = T.Compose([
        T.Resize([x + 10 for x in cfg.INPUT.SIZE_TEST]),
        crop_lb(256, 128),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])
    val_transforms_rb = T.Compose([
        T.Resize([x + 10 for x in cfg.INPUT.SIZE_TEST]),
        crop_rb(256, 128),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    num_workers = cfg.DATALOADER.NUM_WORKERS

    dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)
    num_classes = dataset.num_train_pids

    train_set = ImageDataset(dataset.train, train_transforms)

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  sampler=RandomIdentitySampler(
                                      dataset.train, cfg.SOLVER.IMS_PER_BATCH,
                                      cfg.DATALOADER.NUM_INSTANCE),
                                  num_workers=num_workers,
                                  collate_fn=train_collate_fn)
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(train_set,
                                  batch_size=cfg.SOLVER.IMS_PER_BATCH,
                                  shuffle=True,
                                  num_workers=num_workers,
                                  collate_fn=train_collate_fn)
    else:
        print('unsupported sampler! expected softmax or triplet but got {}'.
              format(cfg.DATALOADER.SAMPLER))

    #val_set = ImageDataset(dataset.query + dataset.gallery , val_transforms)
    #val_loader = DataLoader(
    #    val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
    #    collate_fn=val_collate_fn
    #)
    val_set_normal = ImageDataset(
        dataset.query_normal + dataset.gallery_normal, val_transforms)
    val_loader = DataLoader(val_set_normal,
                            batch_size=cfg.TEST.IMS_PER_BATCH,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=val_collate_fn)
    if cfg.TEST.FLIP_FEATS != 'on':
        val_set_center = ImageDataset(
            dataset.query_normal + dataset.gallery_normal,
            val_transforms_center)
        val_loader_center = DataLoader(val_set_center,
                                       batch_size=cfg.TEST.IMS_PER_BATCH,
                                       shuffle=False,
                                       num_workers=num_workers,
                                       collate_fn=val_collate_fn)
        val_set_lt = ImageDataset(
            dataset.query_normal + dataset.gallery_normal, val_transforms_lt)
        val_loader_lt = DataLoader(val_set_lt,
                                   batch_size=cfg.TEST.IMS_PER_BATCH,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   collate_fn=val_collate_fn)
        val_set_rt = ImageDataset(
            dataset.query_normal + dataset.gallery_normal, val_transforms_rt)
        val_loader_rt = DataLoader(val_set_rt,
                                   batch_size=cfg.TEST.IMS_PER_BATCH,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   collate_fn=val_collate_fn)
        val_set_lb = ImageDataset(
            dataset.query_normal + dataset.gallery_normal, val_transforms_lb)
        val_loader_lb = DataLoader(val_set_lb,
                                   batch_size=cfg.TEST.IMS_PER_BATCH,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   collate_fn=val_collate_fn)
        val_set_rb = ImageDataset(
            dataset.query_normal + dataset.gallery_normal, val_transforms_rb)
        val_loader_rb = DataLoader(val_set_rb,
                                   batch_size=cfg.TEST.IMS_PER_BATCH,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   collate_fn=val_collate_fn)

        return train_loader, val_loader, len(
            dataset.val_query
        ), num_classes, val_loader_center, val_loader_lb, val_loader_rb, val_loader_rt, val_loader_lt

    return train_loader, val_loader, len(dataset.query_normal), num_classes
Beispiel #19
0
def get_dataset(state, phase):
    assert phase in ('train', 'test'), 'Unsupported phase: %s' % phase
    name, root, nc, input_size, num_classes, normalization, _ = get_info(state)
    real_size = dataset_stats[name].real_size

    if name == 'MNIST':
        if input_size != real_size:
            transform_list = [
                transforms.Resize([input_size, input_size], Image.BICUBIC)
            ]
        else:
            transform_list = []
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.MNIST(root,
                                  train=(phase == 'train'),
                                  download=True,
                                  transform=transforms.Compose(transform_list))
    elif name == 'ADULT':
        transform_list = []
        if input_size != real_size:
            print("size doesn't match")
        transform_list += [
            transforms.ToTensor(),
        ]
        with suppress_stdout():
            return adult.DatasetAdult('datasets/clean_adult.csv',
                                      train=(phase == 'train'))

    elif name == 'ADULT_sex':
        transform_list = []
        if input_size != real_size:
            print("size doesn't match")
        transform_list += [
            transforms.ToTensor(),
        ]
        with suppress_stdout():
            return multiclass_adult.DatasetAdult_V2(
                'datasets/adult_multiclass.csv',
                train=(phase == 'train'),
                target=0)

    elif name == 'ADULT_race':
        transform_list = []
        if input_size != real_size:
            print("size doesn't match")
        transform_list += [
            transforms.ToTensor(),
        ]
        with suppress_stdout():
            return multiclass_adult.DatasetAdult_V2(
                'datasets/adult_multiclass.csv',
                train=(phase == 'train'),
                target=1)

    elif name == 'ADULT_workclass':
        transform_list = []
        if input_size != real_size:
            print("size doesn't match")
        transform_list += [
            transforms.ToTensor(),
        ]
        with suppress_stdout():
            return multiclass_adult.DatasetAdult_V2(
                'datasets/adult_multiclass.csv',
                train=(phase == 'train'),
                target=2)

    elif name == 'MNIST_RGB':
        transform_list = [transforms.Grayscale(3)]
        if input_size != real_size:
            transform_list.append(
                transforms.Resize([input_size, input_size], Image.BICUBIC))
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.MNIST(root,
                                  train=(phase == 'train'),
                                  download=True,
                                  transform=transforms.Compose(transform_list))
    elif name == 'USPS':
        if input_size != real_size:
            transform_list = [
                transforms.Resize([input_size, input_size], Image.BICUBIC)
            ]
        else:
            transform_list = []
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return USPS(root,
                        train=(phase == 'train'),
                        download=True,
                        transform=transforms.Compose(transform_list))
    elif name == 'SVHN':
        transform_list = []
        if input_size != real_size:
            transform_list.append(
                transforms.Resize([input_size, input_size], Image.BICUBIC))
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.SVHN(root,
                                 split=phase,
                                 download=True,
                                 transform=transforms.Compose(transform_list))
    elif name == 'Cifar10':
        transform_list = []
        if input_size != real_size:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        if phase == 'train':
            transform_list += [
                # TODO: merge the following into the padding options of
                #       RandomCrop when a new torchvision version is released.
                transforms.Pad(padding=4, padding_mode='reflect'),
                transforms.RandomCrop(input_size),
                transforms.RandomHorizontalFlip(),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.CIFAR10(root,
                                    phase == 'train',
                                    transforms.Compose(transform_list),
                                    download=True)
    elif name == 'CUB200':
        transform_list = []
        if phase == 'train':
            transform_list += [
                transforms.RandomResizedCrop(input_size,
                                             interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        return caltech_ucsd_birds.CUB200(root,
                                         phase == 'train',
                                         transforms.Compose(transform_list),
                                         download=True)
    elif name == 'PASCAL_VOC':
        transform_list = []
        if phase == 'train':
            transform_list += [
                transforms.RandomResizedCrop(input_size,
                                             interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        if phase == 'train':
            phase = 'trainval'
        return pascal_voc.PASCALVoc2007(root, phase,
                                        transforms.Compose(transform_list))

    else:
        raise ValueError('Unsupported dataset: %s' % state.dataset)
Beispiel #20
0
cifar_transform_train = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
cifar_transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

mnist_transform = transforms.Compose([
    transforms.Pad(2),  # pad to 32, 32
    transforms.ToTensor(),
    transforms.Normalize((0.1305, ), (0.3081, ))
])

fmnist_transform = transforms.Compose([
    transforms.Pad(2),  # pad to 32, 32
    transforms.ToTensor(),
    transforms.Normalize((0.2859, ), (0.3530, ))
])
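
Pad(2) adds two pixels on every side, which is exactly what turns the native 28x28 MNIST/Fashion-MNIST digits into 32x32 inputs; a quick self-contained check:

from PIL import Image
import torch

img = Image.new('L', (28, 28))  # dummy grayscale digit
assert mnist_transform(img).shape == torch.Size([1, 32, 32])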


class PreResNet164:
    base = PreResNet
    args = list()
    kwargs = {"depth": 164}
Beispiel #21
0
    def __init__(self,
                 config,
                 config_path,
                 name,
                 dataset,
                 split,
                 model,
                 pretrain: bool,
                 optimizer,
                 lr: float = DEFAULT_LR,
                 momentum: float = DEFAULT_MOMENTUM,
                 weight_decay: float = DEFAULT_WEIGHT_DECAY,
                 start_epoch: int = BaseTrainer.DEFAULT_START_EPOCH):
        super(Trainer, self).__init__(name, dataset, split, model, optimizer,
                                      start_epoch)

        if not lr > 0:
            raise ValueError(
                value_error_msg('lr', lr, 'lr > 0', Trainer.DEFAULT_LR))

        if not momentum >= 0:
            raise ValueError(
                value_error_msg('momentum', momentum, 'momentum >= 0',
                                Trainer.DEFAULT_MOMENTUM))

        if not weight_decay >= 0:
            raise ValueError(
                value_error_msg('weight_decay', weight_decay,
                                'weight_decay >= 0', Trainer.DEFAULT_WEIGHT_DECAY))

        self.config = config
        self.model_path = format_path(
            self.config[self.name.value]['model_format'], self.name.value,
            self.config['Default']['delimiter'])

        if self.split == Trainer.Split.TRAIN_VAL:
            self.phase = ['train', 'val']
        elif self.split == Trainer.Split.TRAIN_ONLY:
            self.phase = ['train']
        else:
            raise ValueError(
                value_error_msg('split', split, BaseTrainer.SPLIT_LIST))

        if self.name == Trainer.Name.MARKET1501:
            transform_train_list = [
                # transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3),
                transforms.Resize((256, 128), interpolation=3),
                transforms.Pad(10),
                transforms.RandomCrop((256, 128)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ]

            transform_val_list = [
                transforms.Resize(size=(256, 128),
                                  interpolation=3),  # Image.BICUBIC
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
            ]

            data_transforms = {
                'train': transforms.Compose(transform_train_list),
                'val': transforms.Compose(transform_val_list),
            }
        else:
            raise ValueError(
                value_error_msg('name', self.name, Trainer.NAME_LIST))

        # dataset declaration
        self.dataset = {}
        self.dataset_sizes = {}

        # dataset loading
        for phase in self.phase:
            folder_name = phase
            if self.split == Trainer.Split.TRAIN_ONLY:
                folder_name = 'total_' + folder_name
            self.dataset[phase] = ImageFolder(
                join(self.config[self.name.value]['dataset_dir'], folder_name),
                data_transforms[phase])
            self.dataset_sizes[phase] = len(self.dataset[phase])

        # record train_class num on setting files
        model_name = self.model.value
        train_class = len(self.dataset['train'].classes)
        config[self.name.value]['train_class'] = str(train_class)
        with open(config_path, 'w+') as file:
            config.write(file)

        # initialize model weights
        if self.model == Trainer.Model.RESNET50:
            self.model = ResNet50(self.config,
                                  train_class,
                                  pretrained=pretrain)
            if self.start_epoch > 0:
                load_model(
                    self.model, self.config[self.name.value]['model_format'] %
                    (model_name, self.start_epoch))
        # else:
        #     raise ValueError(value_error_msg('model', model, Trainer.MODEL_LIST))

        self.suffix = 'pretrain' if pretrain else 'no_pretrain'
        self.train_path = self.config[
            self.name.value]['train_path'] % self.suffix

        # use different settings for different params in model when using optimizers
        ignored_params = list(map(id, self.model.final_block.parameters()))
        base_params = filter(lambda p: id(p) not in ignored_params,
                             self.model.parameters())

        if self.optimizer == BaseTrainer.Optimizer.SGD:
            self.optimizer = optim.SGD(
                [{
                    'params': base_params,
                    'lr': 0.1 * lr
                }, {
                    'params': self.model.final_block.parameters(),
                    'lr': lr
                }],
                weight_decay=weight_decay,
                momentum=momentum,
                nesterov=True)
        # else:
        #     raise ValueError(value_error_msg('optimizer', optimizer, Trainer.OPTIMIZER_LIST))
        self.criterion = nn.CrossEntropyLoss()
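
value_error_msg is imported from elsewhere in the project; as a rough idea of its contract, inferred purely from the call sites above (a hypothetical reconstruction):

def value_error_msg(name, value, constraint, default=None):
    # format a consistent validation error, optionally mentioning the
    # default that should have been used instead
    msg = 'invalid {}={!r}, expected {}'.format(name, value, constraint)
    if default is not None:
        msg += ' (default: {!r})'.format(default)
    return msg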
Beispiel #22
0
    def run(self):
        # Set eval model.
        self.model.eval()

        # read frame
        success, raw_frame = self.video_capture.read()
        progress_bar = tqdm(
            range(self.total_frames),
            desc="[processing video and saving/view result videos]")
        for _ in progress_bar:
            if success:
                # Convert the current video frame to a tensor and transfer it to the specified device for processing.
                img = self.tensor2pil(raw_frame)
                lr = process_image(img, self.device)

                sr = inference(self.model, lr)

                sr = sr.cpu()
                sr = sr.data[0].numpy()
                sr *= 255.0
                sr = (np.uint8(sr)).transpose((1, 2, 0))
                # save sr video
                self.sr_writer.write(sr)

                # build the comparison video and crop shots of the left-top, right-top, center, left-bottom and right-bottom areas
                sr = self.tensor2pil(sr)
                # Five areas are selected as the bottom contrast map.
                crop_sr_imgs = transforms.FiveCrop(size=sr.width // 5 - 9)(sr)
                crop_sr_imgs = [
                    np.asarray(transforms.Pad(padding=(10, 5, 0, 0))(img))
                    for img in crop_sr_imgs
                ]
                sr = transforms.Pad(padding=(5, 0, 0, 5))(sr)
                # Five areas in the contrast map are selected as the bottom contrast map
                compare_img = transforms.Resize(
                    (self.sr_size[1], self.sr_size[0]),
                    interpolation=Image.BICUBIC)(self.tensor2pil(raw_frame))
                crop_compare_imgs = transforms.FiveCrop(
                    size=compare_img.width // 5 - 9)(compare_img)
                crop_compare_imgs = [
                    np.asarray(transforms.Pad(padding=(0, 5, 10, 0))(img))
                    for img in crop_compare_imgs
                ]
                compare_img = transforms.Pad(padding=(0, 0, 5, 5))(compare_img)
                # concatenate all the pictures to one single picture
                # 1. Mosaic the left and right images of the video.
                top_img = np.concatenate(
                    (np.asarray(compare_img), np.asarray(sr)), axis=1)
                # 2. Mosaic the bottom left and bottom right images of the video.
                bottom_img = np.concatenate(crop_compare_imgs + crop_sr_imgs,
                                            axis=1)
                bottom_img_height = int(top_img.shape[1] /
                                        bottom_img.shape[1] *
                                        bottom_img.shape[0])
                bottom_img_width = top_img.shape[1]
                # 3. Adjust to the right size.
                bottom_img = np.asarray(
                    transforms.Resize(
                        (bottom_img_height,
                         bottom_img_width))(self.tensor2pil(bottom_img)))
                # 4. Combine the bottom zone with the upper zone.
                final_image = np.concatenate((top_img, bottom_img))

                # save compare video
                self.compare_writer.write(final_image)

                if self.args.view:
                    # display video
                    cv2.imshow("LR video convert HR video ", final_image)
                    if cv2.waitKey(1) & 0xFF == ord("q"):
                        break

                # next frame
                success, raw_frame = self.video_capture.read()
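
process_image and inference come from the surrounding project; a plausible minimal process_image, assuming it only converts a PIL image into a batched float tensor on the right device (an assumption, not the verified original):

from torchvision import transforms

def process_image(img, device):
    # PIL.Image -> (1, C, H, W) float tensor in [0, 1] on the target device
    tensor = transforms.ToTensor()(img).unsqueeze(0)
    return tensor.to(device)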
Beispiel #23
0
                                          transform=transform_test)
C_test = torch.utils.data.DataLoader(C_testset,
                                     batch_size=100,
                                     shuffle=False,
                                     num_workers=1,
                                     drop_last=True)

m_transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, ))
])

m_transform_test = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1307, ), (0.3081, )),
])
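
Both pipelines land on 32x32: RandomCrop(32, padding=4) first pads the 28x28 image to 36x36 and then crops, while the test path pads 2 pixels per side. A quick sanity check:

from PIL import Image

img = Image.new('L', (28, 28))
assert tuple(m_transform_train(img).shape[-2:]) == (32, 32)
assert tuple(m_transform_test(img).shape[-2:]) == (32, 32)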

fashion_train = mnist_fashion_reader.MNIST_FISHION(root='../../data/fashion',
                                                   train=True,
                                                   download=False,
                                                   transform=m_transform_train)
fashion_test = mnist_fashion_reader.MNIST_FISHION(root='../../data/fashion',
                                                  train=False,
                                                  download=False,
                                                  transform=m_transform_test)

f_train = torch.utils.data.DataLoader(fashion_train,
                                      batch_size=batch_size,
Beispiel #24
0
    if gid >= 0:
        gpu_ids.append(gid)

# set gpu ids
if len(gpu_ids) > 0:
    torch.cuda.set_device(gpu_ids[0])
    cudnn.benchmark = True
######################################################################
# Load Data
# ---------
#

transform_train_list = [
    #transforms.RandomResizedCrop(size=128, scale=(0.75,1.0), ratio=(0.75,1.3333), interpolation=3), #Image.BICUBIC)
    transforms.Resize((256, 128), interpolation=3),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

transform_val_list = [
    transforms.Resize(size=(256, 128), interpolation=3),  #Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

if opt.PCB:
    transform_train_list = [
        transforms.Resize((384, 192), interpolation=3),
Beispiel #25
0
def get_vehicle_dataloader(cfg, quick_check=False):

    source = globals()[cfg['SOURCE']]()
    target = globals()[cfg['TARGET']]()
    if quick_check:
        source_train = source.train[:1000]
        target_train = target.train[:1000]
    else:
        source_train = source.train
        target_train = target.train

    target_test = target.test
    query = target.query
    gallery = target.gallery
    num_gpus = torch.cuda.device_count()
    normalizer = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_transformer = T.Compose([
        T.Resize((cfg['WIDTH'], cfg['HEIGHT'])),
        T.RandomHorizontalFlip(p=0.5),
        T.Pad(0),
        T.RandomCrop((cfg['WIDTH'], cfg['HEIGHT'])),
        T.ToTensor(),
        normalizer,
        RandomErasing(),
    ])

    test_transformer = T.Compose([
        T.Resize((cfg['WIDTH'], cfg['HEIGHT'])),
        T.ToTensor(),
        normalizer,
    ])
    source_loader = DataLoader(
        Preprocessor(source_train, name=cfg['SOURCE'], training=True, transform=train_transformer),
        sampler=RandomIdentitySampler(source_train, cfg['BATCHSIZE'], cfg['INSTANCE'], cfg['SOURCE']),
        batch_size=cfg['BATCHSIZE'],
        num_workers=4,
        pin_memory=True,
    )
    target_loader = DataLoader(
        Preprocessor(target_train, name=cfg['TARGET'], training=True, transform=train_transformer),
        batch_size=cfg['BATCHSIZE'],
        num_workers=4,
        #shuffle=True,
        sampler=RandomIdentitySampler(target_train, cfg['BATCHSIZE'], cfg['INSTANCE'], cfg['TARGET']),
        pin_memory=True,
        drop_last=True,
    )
    test_loader = DataLoader(
        Preprocessor(target_test, name=cfg['TARGET'], training=False, transform=test_transformer),
        batch_size=cfg['BATCHSIZE'],
        num_workers=4,
        shuffle=False,
        pin_memory=True,
    )
    target_cluster_loader = DataLoader(
        Preprocessor(target_train, name=cfg['TARGET'], training=True, transform=test_transformer),
        batch_size=cfg['BATCHSIZE'],
        num_workers=4,
        shuffle=False,
        pin_memory=True,
    )

    return source_loader, target_loader, test_loader, query, gallery, train_transformer, source_train, target_train, target_cluster_loader
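
One thing to watch in the train transformer above: torchvision's Resize interprets a size tuple as (height, width), while PIL reports size as (width, height), so T.Resize((cfg['WIDTH'], cfg['HEIGHT'])) only does what the names suggest if the config already stores the values in (h, w) order. A small demonstration:

from PIL import Image
from torchvision import transforms as T

img = Image.new('RGB', (100, 50))   # PIL size: (width, height)
out = T.Resize((256, 128))(img)     # Resize size: (height, width)
print(out.size)                     # -> (128, 256)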
Beispiel #26
0
def get_dataset(state, phase):
    dataset_stats['imdb'] = DatasetStats(1, state.maxlen, 2)
    dataset_stats['sst5'] = DatasetStats(1, state.maxlen, 5)
    dataset_stats['trec6'] = DatasetStats(1, state.maxlen, 6)
    dataset_stats['trec50'] = DatasetStats(1, state.maxlen, 50)
    dataset_stats['snli'] = DatasetStats(1, state.maxlen, 3)
    dataset_stats['multinli'] = DatasetStats(1, state.maxlen, 3)
    assert phase in ('train', 'test'), 'Unsupported phase: %s' % phase
    name, root, nc, input_size, num_classes, normalization, _ = get_info(state)
    real_size = dataset_stats[name].real_size

    if name == 'MNIST':
        if input_size != real_size:
            transform_list = [
                transforms.Resize([input_size, input_size], Image.BICUBIC)
            ]
        else:
            transform_list = []
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.MNIST(root,
                                  train=(phase == 'train'),
                                  download=True,
                                  transform=transforms.Compose(transform_list))
    elif name == 'MNIST_RGB':
        transform_list = [transforms.Grayscale(3)]
        if input_size != real_size:
            transform_list.append(
                transforms.Resize([input_size, input_size], Image.BICUBIC))
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.MNIST(root,
                                  train=(phase == 'train'),
                                  download=True,
                                  transform=transforms.Compose(transform_list))
    elif name == 'USPS':
        if input_size != real_size:
            transform_list = [
                transforms.Resize([input_size, input_size], Image.BICUBIC)
            ]
        else:
            transform_list = []
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return USPS(root,
                        train=(phase == 'train'),
                        download=True,
                        transform=transforms.Compose(transform_list))
    elif name == 'SVHN':
        transform_list = []
        if input_size != real_size:
            transform_list.append(
                transforms.Resize([input_size, input_size], Image.BICUBIC))
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.SVHN(root,
                                 split=phase,
                                 download=True,
                                 transform=transforms.Compose(transform_list))
    elif name == 'Cifar10':
        transform_list = []
        if input_size != real_size:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        if phase == 'train':
            transform_list += [
                # TODO: merge the following into the padding options of
                #       RandomCrop when a new torchvision version is released.
                transforms.Pad(padding=4, padding_mode='reflect'),
                transforms.RandomCrop(input_size),
                transforms.RandomHorizontalFlip(),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        with suppress_stdout():
            return datasets.CIFAR10(root,
                                    phase == 'train',
                                    transforms.Compose(transform_list),
                                    download=True)
    elif name == 'CUB200':
        transform_list = []
        if phase == 'train':
            transform_list += [
                transforms.RandomResizedCrop(input_size,
                                             interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        return caltech_ucsd_birds.CUB200(root,
                                         phase == 'train',
                                         transforms.Compose(transform_list),
                                         download=True)
    elif name == 'PASCAL_VOC':
        transform_list = []
        if phase == 'train':
            transform_list += [
                transforms.RandomResizedCrop(input_size,
                                             interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            transform_list += [
                transforms.Resize([input_size, input_size], Image.BICUBIC),
            ]
        transform_list += [
            transforms.ToTensor(),
            transforms.Normalize(*normalization),
        ]
        if phase == 'train':
            phase = 'trainval'
        return pascal_voc.PASCALVoc2007(root, phase,
                                        transforms.Compose(transform_list))
    elif name == 'imdb':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, test = textdata.IMDB.splits(TEXT, LABEL)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #man=TEXT.vocab.vectors[TEXT.vocab["man"]].clone()
        #woman=TEXT.vocab.vectors[TEXT.vocab["woman"]].clone()
        #king=TEXT.vocab.vectors[TEXT.vocab["doctor"]].clone()

        #print(torch.norm(king - man + woman))
        #vec = king - man + woman
        #print_closest_words(vec, TEXT.vocab)
        #print_closest_words(king, TEXT.vocab)
        #print(TEXT.vocab.vectors)
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src
    elif name == 'sst5':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, valid, test = textdata.SST.splits(TEXT,
                                                 LABEL,
                                                 fine_grained=True)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        #print(len(TEXT.vocab))
        #print(len(LABEL.vocab))
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src
    elif name == 'trec6':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, test = textdata.TREC.splits(TEXT, LABEL, fine_grained=False)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        #print(len(TEXT.vocab))
        #print(len(LABEL.vocab))
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src
    elif name == 'trec50':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, test = textdata.TREC.splits(TEXT, LABEL, fine_grained=True)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        #print(len(TEXT.vocab))
        #print(len(LABEL.vocab))
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src
    elif name == 'snli':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, valid, test = textdata.SNLI.splits(TEXT, LABEL)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        #print(len(TEXT.vocab))
        #print(len(LABEL.vocab))
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src
    elif name == 'multinli':
        transform_list = []
        # set up fields
        TEXT = data.Field(lower=True,
                          include_lengths=True,
                          batch_first=True,
                          fix_length=state.maxlen)
        LABEL = data.LabelField(dtype=torch.long)

        # make splits for data
        train, valid, test = textdata.MultiNLI.splits(TEXT, LABEL)
        # build the vocabulary
        TEXT.build_vocab(train,
                         vectors=GloVe(name='6B',
                                       dim=state.ninp,
                                       max_vectors=state.ntoken),
                         max_size=state.ntoken - 2)  #max_size=state.ntoken,
        LABEL.build_vocab(train)
        #print(len(TEXT.vocab))
        #print(len(LABEL.vocab))
        state.pretrained_vec = TEXT.vocab.vectors
        state.glove = TEXT.vocab
        #ninp=32 #Maybe 400
        #ntoken=32
        #encoder = nn.Embedding(ntoken, ninp)

        #train_iter, test_iter = textdata.IMDB.iters(batch_size=state.batch_size, fix_length=state.ninp)
        if phase == "train":
            src = train
            #src = encoder(train_iter) * math.sqrt(ninp)
        else:
            src = test
            #src = encoder(test_iter) * math.sqrt(ninp)

        #src = data.Iterator.splits(
        #src, batch_size=state.batch_size, device=state.device, repeat=False, sort_key=lambda x: len(x.src))

        return src

    else:
        raise ValueError('Unsupported dataset: %s' % state.dataset)
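
The commented-out data.Iterator lines hint at how the returned splits are consumed; a minimal sketch using the legacy torchtext API that these Field objects belong to (batch size and attribute names assumed from this excerpt):

from torchtext import data

train_src = get_dataset(state, 'train')
train_iter = data.BucketIterator(train_src,
                                 batch_size=state.batch_size,
                                 shuffle=True)
for batch in train_iter:
    (tokens, lengths), labels = batch.text, batch.label  # include_lengths=True
    break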
Beispiel #27
0
                        im_size=im_size,
                        epc_seed=epc_seed)
        dataset_sizes = {'train': 5e4, 'test': 1e4}
    else:
        raise Exception('Should not have reached here')

    if args.augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(config.padded_im_size, padding=2),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Pad((config.padded_im_size - config.im_size) // 2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    test_transform = transforms.Compose([
        transforms.Pad((config.padded_im_size - config.im_size) // 2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    if args.dataset == 'MNIST':
        train_data = datasets.MNIST(osp.join(args.dataset_root, 'MNIST'),
                                    train=True,
                                    transform=train_transform,
                                    download=True)
        test_data = datasets.MNIST(osp.join(args.dataset_root, 'MNIST'),
Beispiel #28
0
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

print(images.shape)


# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

toPIL = transforms.ToPILImage()
toTensor = transforms.ToTensor()
pad = transforms.Pad(padding=10, fill=0)

# unnormalize the raw batch (N, C, H, W) rather than the make_grid output,
# so the loop below iterates over images instead of over color channels
batch_imgs = (images / 2 + 0.5).numpy()

img_list = []
for i in range(batch_imgs.shape[0]):
    n_img = batch_imgs[i]
    n_img = (n_img * 255).astype(np.uint8)  # scale [0, 1] floats to [0, 255] before casting
    n_img = np.transpose(n_img, [1, 2, 0])  # CHW -> HWC for matplotlib/PIL
    plt.imshow(n_img)
    plt.show()
    images2 = toPIL(n_img)
    pad_img = pad(images2)
    print("pad size", pad_img.size)
Beispiel #29
0
    metavar='N',
    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
    './data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           **kwargs)
test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
    './data',
    train=False,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
Beispiel #30
0
    def __getitem__(self, index):
        """
        Niko Dataset Get Item
        @param index: index
        Returns:
            if self.color_histogram
            tuple: (imageA == original, imageB == sketch, colors)
            else:
            tuple: (imageA == original, imageB == sketch)

            if self.resize
            resized image will be appended end of the above tuple
        """
        filename = self.image_files[index]
        file_id = filename.split('/')[-1][:-4]

        if self.color_histogram:
            # build colorgram tensor
            color_info = self.color_cache.get(file_id, None)
            if color_info is None:
                with open(
                        os.path.join('./data/colorgram',
                                     '%s.json' % file_id).replace('\\', '/'),
                        'r') as json_file:
                    # load color info dictionary from json file
                    color_info = json.loads(json_file.read())
                    self.color_cache[file_id] = color_info
            colors = make_colorgram_tensor(color_info)

        image = Image.open(filename)
        image_width, image_height = image.size
        imageA = image.crop((0, 0, image_width // 2, image_height))
        imageB = image.crop((image_width // 2, 0, image_width, image_height))

        # default transforms, pad if needed and center crop 256
        width_pad = self.size - image_width // 2
        if width_pad < 0:
            # do not pad
            width_pad = 0

        height_pad = self.size - image_height
        if height_pad < 0:
            height_pad = 0

        # padding as white
        padding = transforms.Pad((width_pad // 2, height_pad // 2 + 1,
                                  width_pad // 2 + 1, height_pad // 2),
                                 (255, 255, 255))

        # use center crop
        crop = transforms.CenterCrop(self.size)

        imageA = padding(imageA)
        imageA = crop(imageA)

        imageB = padding(imageB)
        imageB = crop(imageB)

        if self.transform is not None:
            imageA = self.transform(imageA)
            imageB = self.transform(imageB)

        # scale image into range [-1, 1]
        imageA = scale(imageA)
        imageB = scale(imageB)
        if not self.color_histogram:
            return imageA, imageB
        else:
            return imageA, imageB, colors
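
scale is referenced above but defined elsewhere; given the comment about mapping into [-1, 1], a plausible one-liner (an assumption, not the verified original):

def scale(tensor):
    # map a ToTensor-style image in [0, 1] to [-1, 1]
    return tensor * 2.0 - 1.0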