Example #1
0
def load_cifar10_class_subsets(classes,
                               train=True,
                               device="cpu",
                               dtype=torch.float,
                               minmax=(0, 1.0)):
    """
    Load and return a subset of the CIFAR10 dataset, including only the
    specified classes.

    - classes: a list of class indices (as returned by label_class_list()).
    - train: forwarded to load_cifar10_dataset(); selects the training split
      when True, otherwise the test split.
    - device: torch device on which the returned tensors are created.
    - dtype: torch dtype of both the image tensor and the label tensor.
    - minmax: transform the range of pixel intensity from (0, 255) to this
      range. Only the images are rescaled; labels keep their class indices.

    Return a torch.utils.data.TensorDataset of (images, labels). The image
    array is reordered with transpose(0, 3, 1, 2); this assumes the raw arrays
    are (n, height, width, channels) as in torchvision's CIFAR10 — TODO confirm
    for load_cifar10_dataset().
    """
    dataset = load_cifar10_dataset(train)
    X, Y = _cifar10_numpy_arrays(dataset, train)

    # Boolean mask of rows whose label is one of the requested classes.
    # list() so that sets / generators behave like the original `in` test.
    keep = np.isin(Y, list(classes))

    # Pixel intensities are rescaled; labels are class indices and must NOT be.
    Xsub = util.linear_range_transform(X[keep], (0, 255), minmax)
    Ysub = Y[keep]

    TX = torch.tensor(Xsub.transpose(0, 3, 1, 2), device=device, dtype=dtype)
    TY = torch.tensor(Ysub, device=device, dtype=dtype)
    return torch.utils.data.TensorDataset(TX, TY)


def _cifar10_numpy_arrays(dataset, train):
    """Return (images, labels) as numpy arrays from a CIFAR10-style dataset.

    Newer torchvision releases expose ``data``/``targets``. Older ones expose
    ``train_data``/``train_labels`` or ``test_data``/``test_labels`` depending
    on the split. If none of the newer/split-specific names exist, fall back
    to ``train_data``/``train_labels`` (what the original code always read).
    """
    if hasattr(dataset, "data") and hasattr(dataset, "targets"):
        images, labels = dataset.data, dataset.targets
    elif not train and hasattr(dataset, "test_data"):
        images, labels = dataset.test_data, dataset.test_labels
    else:
        images, labels = dataset.train_data, dataset.train_labels
    return np.asarray(images), np.asarray(labels)
Example #2
0
    def test_linear_range_transform(self):
        """linear_range_transform maps values affinely between two ranges."""
        # (input values, source range, target range, expected output)
        cases = [
            (np.array([0.1, 0.8]), (0, 1), (0, 10), np.array([1, 8])),
            (np.array([-0.2, 0.3]), (-1, 1), (-10, 10), np.array([-2, 3])),
        ]
        for values, source_range, target_range, expected in cases:
            actual = util.linear_range_transform(values, source_range,
                                                 target_range)
            testing.assert_almost_equal(expected, actual)
Example #3
0
    def forward(self, z):
        """Generate images from latent codes ``z``.

        ``self.l1(z)`` is reshaped to (out.shape[0], 128, init_size, init_size)
        and passed through ``self.conv_blocks``. That output is taken to be in
        [-1, 1] (a final tanh, per the original comment) and is linearly
        rescaled to ``self.minmax``.
        """
        hidden = self.l1(z)
        side = self.init_size
        hidden = hidden.view(hidden.shape[0], 128, side, side)
        raw = self.conv_blocks(hidden)

        # Stretch tanh's [-1, 1] onto the configured pixel range.
        return util.linear_range_transform(raw, from_range=(-1, 1),
                                           to_range=self.minmax)
Example #4
0
 def forward(self, z):
     """Generate images from latent vectors ``z``.

     ``z`` is reshaped to (-1, latent_dim, 1, 1) and fed to ``self.main``.
     The result is taken to be in [-1, 1] (tanh output, per the original
     comment) and is linearly rescaled to ``self.minmax``.
     """
     latent = z.view(-1, self.latent_dim, 1, 1)
     raw = self.main(latent)

     # Map tanh's [-1, 1] onto the configured pixel range.
     target_range = self.minmax
     return util.linear_range_transform(raw, from_range=(-1, 1),
                                        to_range=target_range)
Example #5
0
 def forward(self, z):
     """Generate (channels x 32 x 32) images from latent codes ``z``.

     ``self.net(z)`` is reshaped to (out.shape[0], channels, 32, 32). The
     result is taken to be in [-1, 1] (tanh output, per the original comment)
     and is rescaled to ``self.minmax``. The rescaled values are then asserted
     to lie within that range.
     """
     flat = self.net(z)
     batch = flat.shape[0]
     raw = flat.view(batch, self.channels, 32, 32)

     bounds = self.minmax
     # Map tanh's [-1, 1] onto the configured pixel range.
     img = util.linear_range_transform(raw, from_range=(-1, 1), to_range=bounds)

     # Sanity checks on the rescaled output (skipped under `python -O`).
     assert (img >= bounds[0]).all()
     assert (img <= bounds[1]).all()
     return img
Example #6
0
    def forward(self, img):
        """Score a batch of images with the discriminator.

        Input pixels are expected in ``self.minmax``. They are linearly mapped
        to [-1, 1] and passed through ``self.model``. The features are then
        flattened to 2-D (keeping the leading dimension) and scored by
        ``self.adv_layer``.
        """
        # Normalize from the configured pixel range to [-1, 1] first.
        normalized = util.linear_range_transform(
            img, from_range=self.minmax, to_range=(-1.0, 1.0))

        features = self.model(normalized)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)
Example #7
0
    def forward(self, z):
        """Generate images from latent codes ``z``, checked to lie in ``self.minmax``.

        ``self.l1(z)`` is reshaped to (out.shape[0], 128, init_size, init_size)
        and run through ``self.conv_blocks``. The result is taken to be in
        [-1, 1] (tanh output, per the original comment) and is rescaled to
        ``self.minmax``. The rescaled values are then asserted to lie within
        that range.
        """
        hidden = self.l1(z)
        side = self.init_size
        hidden = hidden.view(hidden.shape[0], 128, side, side)
        raw = self.conv_blocks(hidden)

        bounds = self.minmax
        # Stretch tanh's [-1, 1] onto the configured pixel range.
        img = util.linear_range_transform(raw, from_range=(-1, 1), to_range=bounds)

        # Sanity checks on the rescaled output (skipped under `python -O`).
        assert (img >= bounds[0]).all()
        assert (img <= bounds[1]).all()
        return img
Example #8
0
def numpy_image_to_float(img, out_range=(0, 1)):
    """
    Convert a numpy image (3d tensor, h x w x channels) into float format where
    the output range is as given in out_range.
    Return a numpy array.

    How img_as_float treats each input dtype:
    - unsigned-integer (and bool) images are scaled to [0, 1];
    - signed-integer images are scaled to [-1, 1];
    - float images are passed through unchanged.

    For non-float inputs, the source range of the final linear rescaling is
    taken from the input dtype rather than from the pixel values. Float
    images carry no intrinsic range, so [-1, 1] is assumed if any value is
    negative and [0, 1] otherwise.

    Raises ValueError if, after img_as_float, the values are not all within
    the assumed source range (this includes NaN values).
    """
    # http://scikit-image.org/docs/dev/api/skimage.util.html#skimage.util.img_as_float
    in_dtype = np.asanyarray(img).dtype
    new_img = skimage.util.img_as_float(img)

    if np.issubdtype(in_dtype, np.signedinteger):
        # Decide by dtype, not data: a signed image that happens to contain
        # no negative pixels is still on the [-1, 1] scale.
        from_range = (-1, 1)
    elif np.issubdtype(in_dtype, np.floating) and np.any(new_img < 0):
        from_range = (-1, 1)
    else:
        from_range = (0, 1)

    low, high = from_range
    # Explicit check instead of asserts so validation survives `python -O`.
    # np.all(...) is False for NaN, so NaN inputs are rejected as before.
    if not (np.all(new_img >= low) and np.all(new_img <= high)):
        raise ValueError(
            "numpy_image_to_float: expected values in [{}, {}] after "
            "img_as_float, got [{}, {}]".format(low, high, np.min(new_img),
                                                 np.max(new_img)))

    # transform to the specified out_range
    return util.linear_range_transform(new_img, from_range, out_range)