def _scat_data1(scat_img_dir, scat_out_dir):
            """Scattering-transform every image in a directory and save the results.

            Files are processed in sorted name order. Each image is rescaled as
            ``value / 127.5 - 1.0``, cast to float32 and transposed from
            height-width-channel to channel-first layout, so colour (3-D) images
            are expected. Images are pushed through the transform in batches of
            ``batch_size`` and one ``.npy`` file per image is written to
            ``scat_out_dir``.

            The names ``self``, ``scat_J``, ``img_shape`` and ``batch_size`` are
            taken from the enclosing scope.

            Args:
                scat_img_dir: Directory holding the input images.
                scat_out_dir: Directory receiving ``<image stem>.npy`` files.
            """
            filename_list = sorted(os.listdir(scat_img_dir))
            count = len(filename_list)

            # Build the transform once; move it to the GPU only when requested.
            scat = Scattering2D(J=scat_J, shape=img_shape)
            if self.cuda:
                scat = scat.cuda()

            batch_images = []
            batch_names = []  # kept in lockstep with batch_images
            for count_idx, filename in enumerate(filename_list):
                img_path = os.path.join(scat_img_dir, filename)
                # Read the colour image; the context manager releases the file
                # handle as soon as the pixels have been copied out.
                with Image.open(img_path) as pil_img:
                    pixels = np.array(pil_img)
                img = np.float32(pixels / 127.5 - 1.0).transpose(2, 0, 1)
                batch_images.append(img)
                batch_names.append(filename)

                # Flush only on a full batch or on the very last image.
                batch_full = (count_idx + 1) % batch_size == 0
                if not batch_full and count_idx != count - 1:
                    continue

                print("In processing images: {}".format(count_idx + 1))

                batch_tensor = torch.from_numpy(np.array(batch_images))
                if self.cuda:
                    # Results are brought back to host memory before saving.
                    batch_scat = scat.forward(batch_tensor.cuda()).cpu()
                else:
                    batch_scat = scat.forward(batch_tensor)

                for name, img_scat in zip(batch_names, batch_scat):
                    # splitext drops only the final extension, so names such as
                    # "a.1.png" and "a.2.png" no longer collide on "a.npy".
                    stem = os.path.splitext(name)[0]
                    np.save(os.path.join(scat_out_dir, stem + '.npy'), img_scat)

                batch_images = []
                batch_names = []

            print("Scattering transform over for {} -> {}".format(scat_img_dir, scat_out_dir))
            return
Example #2
0
    def __init__(self, *args, scattering=None):
        """Set up the batch-norm and linear layers and store the transform.

        Args:
            *args: At least three sizes. ``args[0]`` is the channel count
                given to ``nn.BatchNorm2d``; ``args[0] * args[1] * args[2]``
                is the input width of the final 10-way linear layer.
            scattering: Transform stored on ``self.scattering``. When left
                as ``None`` a ``Scattering2D(shape=(28, 28), J=2)`` is built.
        """
        super(ScatterModel, self).__init__()

        self.shape = args
        self.bn1 = nn.BatchNorm2d(args[0])
        self.out = nn.Linear(args[0] * args[1] * args[2], 10)

        # Identity check: ``== None`` would go through the argument's own
        # ``__eq__`` instead of testing for the missing-argument sentinel.
        if scattering is None:
            scattering = Scattering2D(shape=(28, 28), J=2)
        self.scattering = scattering
    def __init__(self, in_channels, J, shape, max_order, L=8, k=1, alpha=None):
        """Create the scattering transform plus the mixing parameters.

        Args:
            in_channels: Input channel count, forwarded to
                ``calculate_channels_after_scat``.
            J: Forwarded to ``Scattering2D`` and the channel calculation.
            shape: Forwarded to ``Scattering2D``.
            max_order: Forwarded to ``Scattering2D`` and, as ``order``, to
                the channel calculation.
            L: Forwarded to ``Scattering2D`` and the channel calculation.
            k: Spatial size of the learned mixing kernel ``self.A``.
            alpha: ``None`` for no expansion kernel, or one of ``'impulse'``,
                ``'smooth'``, ``'random'`` to create a fixed
                (``requires_grad=False``) kernel stored in ``self.alpha``.

        Raises:
            ValueError: If ``alpha`` is given together with ``k > 1``, or if
                ``alpha`` is not one of the supported values.
        """
        super().__init__()

        self.scatNet = Scattering2D(J=J, shape=shape, max_order=max_order, L=L)

        channels_after_scat = calculate_channels_after_scat(
            in_channels=in_channels, order=max_order, J=J, L=L)

        if k > 1 and alpha is not None:
            raise ValueError("Only use alpha when k=1")

        # Create the learned mixing weights and possibly the expansion kernel
        self.A = nn.Parameter(
            torch.randn(channels_after_scat, channels_after_scat, k, k))
        self.b = nn.Parameter(torch.zeros(channels_after_scat, ))
        if alpha == 'impulse':
            self.alpha = nn.Parameter(random_postconv_impulse(
                channels_after_scat, channels_after_scat),
                                      requires_grad=False)
            self.pad = 1
        elif alpha == 'smooth':
            self.alpha = nn.Parameter(random_postconv_smooth(
                channels_after_scat, channels_after_scat, σ=1),
                                      requires_grad=False)
            self.pad = 1
        elif alpha == 'random':
            self.alpha = nn.Parameter(torch.randn(channels_after_scat,
                                                  channels_after_scat, 3, 3),
                                      requires_grad=False)
            # In-place initialiser; the un-suffixed name is a deprecated alias.
            init.xavier_uniform_(self.alpha)
            self.pad = 1
        elif alpha is None:
            # No expansion kernel: alpha is the scalar 1 and the padding is
            # derived from the size of the mixing kernel.
            self.alpha = 1
            self.pad = (k - 1) // 2
        else:
            raise ValueError(
                "alpha must be None, 'impulse', 'smooth' or 'random', "
                "got {!r}".format(alpha))
Example #4
0
        description='CIFAR scattering  + hybrid examples')
    parser.add_argument('--mode',
                        type=int,
                        default=1,
                        help='scattering 1st or 2nd order')
    parser.add_argument('--width',
                        type=int,
                        default=2,
                        help='width factor for resnet')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(32, 32), max_order=1)
        K = 17 * 3
    else:
        scattering = Scattering2D(J=2, shape=(32, 32))
        K = 81 * 3
    if use_cuda:
        scattering = scattering.cuda()

    model = Scattering2dResNet(K, args.width).to(device)

    # DataLoaders
    if use_cuda:
        num_workers = 4
        pin_memory = True
    else:
        num_workers = None
Example #5
0
# Display the source image with a title.
plt.imshow(src_img)
plt.title("Original image")

# Move the last axis to the front.
src_img = np.moveaxis(src_img, -1, 0)  # HWC to CHW
max_iter = 5  # number of gradient-descent steps
print("Image shape: ", src_img.shape)
# The array must be 3-D here: its axes are unpacked as (channels, height, width).
channels, height, width = src_img.shape

###############################################################################
#  Main loop
# ----------
# Sweep over the scattering order and the J values listed below.
for order in [1]:
    for J in [2, 4]:

        # Compute scattering coefficients
        scattering = Scattering2D(J=J, shape=(height, width), max_order=order)
        # NOTE(review): this branch is taken only when `device` is exactly the
        # string "cuda"; a value such as "cuda:0" would skip it — confirm.
        if device == "cuda":
            scattering = scattering.cuda()
            # Raise the step count when running on CUDA. The new value stays
            # in effect for every later iteration of both loops.
            max_iter = 500
        src_img_tensor = torch.from_numpy(src_img).to(device).contiguous()
        scattering_coefficients = scattering(src_img_tensor)

        # Create trainable input image
        # Uniform random start with the same shape as the source image.
        input_tensor = torch.rand(src_img.shape,
                                  requires_grad=True,
                                  device=device)

        # Optimizer hyperparams
        # The image tensor itself is the only optimised variable.
        optimizer = optim.Adam([input_tensor], lr=1)

        # Training
Example #6
0
# print(x.shape)
# =============================================================================

# Sizes of the random benchmark tensor, in the order they are passed to
# torch.randn below; B is also reused as the DataLoader batch size.
B, C, W, H = 256, 3, 32, 32
epochs = 20
x = torch.randn(B,C,W,H, device=device)

# Time the construction of the convolutional model and its move to `device`.
t0 = time.time()
conv = torch.nn.Conv2d(C, 9*C, 3, padding=1, stride=2, bias=False)
conv.to(device)
t_model_conv = time.time()-t0

# Time the construction of the scattering model and its move to `device`.
t0 = time.time()
scatter = Scattering2D(J=1, shape=(W,H)).to(device)
t_model_scatter =time.time() - t0

print("init time, scatter: {}, conv: {} ".format(t_model_scatter, t_model_conv))

# Load data
# Per-channel normalisation constants — presumably CIFAR-10 statistics; confirm.
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2023, 0.1994, 0.2010])

# NOTE(review): the loader is named `train_loader` but the dataset is built
# with train=False — confirm which split is intended.
train_loader = torch.utils.data.DataLoader(
  datasets.CIFAR10('root/', train=False, download=True,
                             transform=transforms.Compose([
                               transforms.ToTensor(),
                               normalize
                             ])),
  batch_size=B, shuffle=True, num_workers=2)
Example #7
0
        return x


##############################################Scattering Network#########################################################


class View(nn.Module):
    """Module wrapper around ``Tensor.view`` with a free leading dimension.

    The sizes given to the constructor are kept in ``self.shape``; the
    forward pass reinterprets its input as ``(-1, *self.shape)``.
    """

    def __init__(self, *args):
        super().__init__()
        self.shape = args

    def forward(self, x):
        # The leading -1 lets torch infer the first dimension.
        target_size = (-1,) + tuple(self.shape)
        return x.view(target_size)


# Module-level scattering transform configured for 28x28 inputs with J=2.
scattering = Scattering2D(shape=(28, 28), J=2)

# Channel count used by the head below (BatchNorm2d width and, with the 7x7
# spatial size, the input width of the linear layer).
K = 81
# Head: view as (-1, K, 7, 7) -> batch norm -> view as (-1, K*7*7) -> 10-way
# linear layer.
model = nn.Sequential(View(K, 7, 7), nn.BatchNorm2d(K), View(K * 7 * 7),
                      nn.Linear(K * 7 * 7, 10))


class ScatterModel(nn.Module):
    def __init__(self, *args, scattering=None):
        super(ScatterModel, self).__init__()

        self.shape = args
        self.bn1 = nn.BatchNorm2d(args[0])
        self.out = nn.Linear(args[0] * args[1] * args[2], 10)

        if scattering == None:
Example #8
0
from matplotlib import pyplot as plt
import time

from kymatio.torch import Scattering2D
# Device used for both the input tensor and the scattering model.
device = "cuda:0"

# Start of the wall-clock measurement reported at the end of the script.
t0 = time.time()

#img = Image.open('../trump.jpg')
img = Image.open('../chris.jpeg')
img_tensor = Fv.to_tensor(img)
x = img_tensor[None]  # add a leading batch dimension
# Use the configured device rather than whatever the current CUDA device is.
x = x.to(device)

# The last two sizes of x are passed on unchanged as the transform's shape.
B, C, W, H = x.shape
model = Scattering2D(J=2, shape=(W, H), L=8)
model.to(device)
y = model(x)
# Merge dimensions 1 and 2 of the 5-D output into one channel axis.
y = y.view(y.size(0), -1, y.size(3), y.size(4))

# Iterate over however many maps the transform produced, so the loop also
# works when the image does not have exactly three channels.
for i in range(y.size(1)):
    name = "out/order2/scatter/chris_scatt_{}.png".format(i)
    #torchvision.utils.save_image(y[0,i:i+1].detach().cpu(), name)
    plt.imshow(y[0, i].detach().cpu().numpy())
    plt.savefig(name)
    plt.show()

tFinnish = time.time() - t0
print("time: ", tFinnish)
Example #9
0
                        smpls = stimgs[i, j,
                                       l] / (stimgs[i, j, 2 *
                                                    (l - L * J - 1) //
                                                    (L * (J - 1)) + 1] + 1e-16)
                else:
                    if l == 0:
                        bns = np.arange(NBINS + 1) / NBINS
                    else:
                        bns = np.arange(NBINS + 1) / (NBINS * 16)
                    smpls = stimgs[i, j, l]
                h, _ = np.histogram(smpls, bns, range=(bns[0], bns[-1]))
                h[-1] += (smpls > bns[-1]).sum()
                features[i, j, l] = h / h.sum()
    return features


# One scattering transform, moved to the GPU and reused for every video.
# The positional arguments are presumably J=4 and the frame shape — confirm
# against the Scattering2D signature.
st = Scattering2D(4, (288, 352), L=4).cuda()

for dt in ['alpha', 'beta', 'gamma']:
    print('Processing split', dt)
    # NOTE(review): `features` is never filled and `labels` is only appended
    # to in the code shown here — confirm whether they are used later.
    features = []
    labels = []
    # One iteration per entry found directly under data/dyntex_<split>/.
    for i in tqdm.tqdm(range(len(glob.glob('data/dyntex_' + dt + '/*')))):
        # Videos of this iteration: data/dyntex_<split>/c<i+1>_*/*.avi
        files = glob.glob('data/dyntex_' + dt + '/c' + str(i + 1) + '_*/*.avi')
        for f in files:
            labels.append(i)
            # Outputs are written next to the video: the last four characters
            # of the path (".avi") are replaced by the suffixes below.
            # Normalized Scattering Transform:
            np.save(f[:-4] + '_nst.npy', readvid_gr_features(f, st, (4, 4)))
            # Regular Scattering Transform
            np.save(f[:-4] + '_st.npy', readvid_gr_features(f, st))
def get_scatterNet(processing_arguments, setup_arguments):
    """Build the scattering front-end selected by ``setup_arguments``.

    The result is stored on ``processing_arguments.scatNet``. It stays
    ``None`` when ``shared_arguments.enable_scat`` is not 1 or when
    ``scat_type`` is not one of ``'mallat'``, ``'mallat_l'``, ``'dtcwt'``,
    ``'dtcwt_l'``.

    For the DTCWT variants an average-pooling layer is appended when
    ``setup_arguments.half_scat_feat_resolution`` is truthy.

    Args:
        processing_arguments: Object that receives the ``scatNet`` attribute.
        setup_arguments: Configuration providing ``J``, ``scat_order``,
            ``in_shape``, ``in_channels``, ``half_scat_feat_resolution`` and
            ``shared_arguments`` (with ``scat_type`` and ``enable_scat``).

    Raises:
        ValueError: For ``'dtcwt'`` and ``'dtcwt_l'`` when ``J`` is not 1 or
            ``scat_order`` is neither 1 nor 2.
    """
    scat_type = setup_arguments.shared_arguments.scat_type
    J = setup_arguments.J
    scat_order = setup_arguments.scat_order
    in_shape = setup_arguments.in_shape
    in_channels = setup_arguments.in_channels

    scat_post_avpool_kernel_size = None
    scat_post_avpool_stride = 1
    scat_post_avpool_padding = 0

    scatNet = None

    dtcwt_fam = ["dtcwt", "dtcwt_l"]
    malalt_fam = ["mallat", "mallat_l"]  # currently unused

    if setup_arguments.shared_arguments.enable_scat == 1:
        if scat_type == 'mallat':
            scatNet = Scattering2D(J=J,
                                   shape=(in_shape, in_shape),
                                   max_order=scat_order)

        elif scat_type == 'mallat_l':
            scatNet = Scattering2DMixed(in_channels=in_channels,
                                        J=J,
                                        shape=(in_shape, in_shape),
                                        max_order=scat_order,
                                        L=8,
                                        k=1,
                                        alpha=None)

        elif scat_type == 'dtcwt':
            if J != 1:
                raise ValueError(
                    "J can only be 1 in the current implementation of DTWCT")
            if scat_order == 1:
                scatNet = nn.Sequential(
                    OrderedDict([('order1', ScatLayerj1(2))]))
                # use half of scat output +1 to half the feature resolution
                scat_post_avpool_kernel_size = in_shape / 4 + 1
            elif scat_order == 2:
                scatNet = nn.Sequential(
                    OrderedDict([('order1', ScatLayerj1(2)),
                                 ('order2', ScatLayerj1(2))]))
                # use half of scat output +1 to half the feature resolution
                scat_post_avpool_kernel_size = in_shape / 8 + 1
            else:
                raise ValueError(
                    "Scattering order of 1 and 2 only available for this implementation of DTWCT"
                )

        elif scat_type == 'dtcwt_l':
            if J != 1:
                raise ValueError(
                    "J can only be 1 in the current implementation of DTWCT")
            if scat_order == 1:
                scatNet = nn.Sequential(
                    OrderedDict([('order1', InvariantLayerj1(in_channels))
                                 ]))
                # use half of scat output +1 to half the feature resolution
                scat_post_avpool_kernel_size = in_shape / 4 + 1
            elif scat_order == 2:
                scatNet = nn.Sequential(
                    OrderedDict([('order1', InvariantLayerj1(in_channels)),
                                 ('order2',
                                  InvariantLayerj1(7 * in_channels))]))
                # use half of scat output +1 to half the feature resolution
                scat_post_avpool_kernel_size = in_shape / 8 + 1
            else:
                # Same failure mode as the 'dtcwt' branch, instead of silently
                # handing back scatNet = None for an unsupported order.
                raise ValueError(
                    "Scattering order of 1 and 2 only available for this implementation of DTWCT"
                )

        if scat_type in dtcwt_fam \
            and scat_post_avpool_kernel_size is not None\
            and setup_arguments.half_scat_feat_resolution:

            # The kernel size was computed with true division; the pooling
            # layer is given its integer part.
            scat_post_avpool_kernel_size = int(scat_post_avpool_kernel_size)
            scatNet = nn.Sequential(
                OrderedDict([('scatnet', scatNet),
                             ('post_scat_pool',
                              nn.AvgPool2d(scat_post_avpool_kernel_size,
                                           stride=scat_post_avpool_stride,
                                           padding=scat_post_avpool_padding))
                             ]))

    processing_arguments.scatNet = scatNet
Example #11
0
def get_scatter_transform(dataset):
    """Return the scattering transform and output sizes for ``dataset``.

    The dataset's entry in ``SHAPES`` is indexed as (height, width, channels).

    Returns:
        A tuple ``(scattering, K, (h // 4, w // 4))`` where ``scattering`` is
        a ``Scattering2D`` with ``J=2`` built for the first two sizes and
        ``K`` is 81 times the third size.
    """
    dims = SHAPES[dataset]
    spatial = dims[:2]
    transform = Scattering2D(J=2, shape=spatial)
    n_coefficients = 81 * dims[2]
    rows, cols = spatial
    reduced = (rows // 4, cols // 4)
    return transform, n_coefficients, reduced