def regularisation_term(x, uniform_grid):
    """
    :param x: flattened grid
    :param uniform_grid: simple grid
    :return: regularisation term
    """
    x = x.reshape(uniform_grid.shape)

    down = np.pad(np.arange(1, x.shape[0]), (0, 1), mode='edge')
    up = np.pad(np.arange(0, x.shape[0] - 1), (1, 0), mode='edge')
    right = np.pad(np.arange(1, x.shape[1]), (0, 1), mode='edge')
    left = np.pad(np.arange(0, x.shape[1] - 1), (1, 0), mode='edge')

    return np.sum((x - x[down, :]) ** 2) + np.sum((x - x[up, :]) ** 2) \
           + np.sum((x - x[:, right]) ** 2) + np.sum((x - x[:, left]) ** 2)
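
A minimal usage sketch (assuming the grid is an (H, W, 2) array of 2-D point coordinates, passed to the optimizer in flattened form):

import numpy as np

H, W = 4, 5
uniform_grid = np.stack(np.meshgrid(np.arange(W), np.arange(H)), axis=2).astype(float)
x0 = uniform_grid.ravel()  # flattened grid, as an optimizer would see it
print(regularisation_term(x0, uniform_grid))  # sums squared neighbor differences in all four directions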
Example #2
def _pad(arr, newshape, axes=None, mode="constant", constant_values=0):
    """Pad an array to fit into newshape

    Pad `arr` with zeros to fit into newshape,
    which uses the `np.fft.fftshift` convention of moving
    the center pixel of `arr` (if `arr.shape` is odd) to
    the center-right pixel in an even shaped `newshape`.
    """
    if axes is None:
        newshape = np.asarray(newshape)
        currshape = np.array(arr.shape)
        dS = newshape - currshape
        startind = (dS + 1) // 2
        endind = dS - startind
        pad_width = list(zip(startind, endind))
    else:
        # only pad the axes that will be transformed
        pad_width = [(0, 0) for _ in arr.shape]
        try:
            len(axes)
        except TypeError:
            axes = [axes]
        for a, axis in enumerate(axes):
            dS = newshape[a] - arr.shape[axis]
            startind = (dS + 1) // 2
            endind = dS - startind
            pad_width[axis] = (startind, endind)
    if mode == "constant" and constant_values == 0:
        result = fast_zero_pad(arr, pad_width)
    else:
        result = np.pad(arr, pad_width, mode=mode)
    return result
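
A quick check of the fftshift-centering convention described in the docstring (a hedged sketch; mode="edge" is used here to avoid the fast_zero_pad helper, which is not shown):

import numpy as np

arr = np.arange(9.0).reshape(3, 3)
padded = _pad(arr, (6, 6), mode="edge")
print(padded.shape)  # (6, 6); the center pixel of arr lands at the center-right index (3, 3)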
Example #3
def image_to_column(images, filter_shape, stride, padding):
    """Rearrange image blocks into columns.

    Parameters
    ----------

    filter_shape : tuple(height, width)
    images : np.array, shape (n_images, n_channels, height, width)
    padding: tuple(height, width)
    stride : tuple (height, width)

    """
    n_images, n_channels, height, width = images.shape
    f_height, f_width = filter_shape
    out_height, out_width = convolution_shape(height, width, (f_height, f_width), stride, padding)
    images = np.pad(images, ((0, 0), (0, 0), padding, padding),
                    mode='constant')

    col = np.zeros(
        (n_images, n_channels, f_height, f_width, out_height, out_width))
    for y in range(f_height):
        y_bound = y + stride[0] * out_height
        for x in range(f_width):
            x_bound = x + stride[1] * out_width
            col[:, :, y, x, :, :] = images[:, :, y:y_bound:stride[0],
                                           x:x_bound:stride[1]]

    col = col.transpose(0, 4, 5, 1, 2,
                        3).reshape(n_images * out_height * out_width, -1)
    return col
Example #4
def nll_GLM_GanmorCalciumAR1(w, X, Y, hyperparams, nlfun, S=10):
    """
    Negative log-likelihood for a GLM with Ganmor AR1 mixture model for calcium imaging data.

    Input:
        w:              [D x 1]  vector of GLM regression weights
        X:              [T x D]  design matrix
        Y:              [T x 1]  calcium fluorescence observations
        hyperparams:    [3 x 1]  model hyperparameters: tau, alpha, Gaussian variance
        nlfun:          [func]   function handle for nonlinearity, returning (f, log f, f', f'')
        S:              [scalar] number of spikes to marginalize over

    Output:
        negative log-likelihood, gradient, and Hessian
    """

    # unpack hyperparams
    tau, alpha, sig2 = hyperparams

    # compute AR(1) diffs
    taudecay = np.exp(-1.0 / tau)  # decay factor for one time bin
    Y = np.pad(Y, (1, 0))  # pad Y by a time bin
    Ydff = (Y[1:] - taudecay * Y[:-1]) / alpha

    # compute grid of spike counts
    ygrid = np.arange(0, S + 1)

    # Gaussian log-likelihood terms
    log_gauss_grid = -0.5 * (Ydff[:, None] - ygrid[None, :])**2 / (
        sig2 / alpha**2) - 0.5 * np.log(2.0 * np.pi * sig2)

    Xproj = X @ w
    poissConst = gammaln(ygrid + 1)

    # compute neglogli, gradient, and (optionally) Hessian
    f, logf, df, ddf = nlfun(Xproj)
    logPcounts = logf[:, None] * ygrid[None, :] - f[:, None] - poissConst[None, :]

    # compute log-likelihood for each time bin
    logjoint = log_gauss_grid + logPcounts
    logli = logsumexp(logjoint, axis=1)  # log likelihood for each time bin
    negL = -np.sum(logli)  # negative log likelihood

    # gradient
    # derivative of the Poisson log-likelihood
    dLpoiss = (df / f)[:, None] * ygrid[None, :] - df[:, None]
    # gradient weights
    gwts = np.sum(np.exp(logjoint - logli[:, None]) * dLpoiss, axis=1)
    gradient = -X.T @ gwts

    # Hessian
    ddLpoiss = (ddf / f - (df / f)**2)[:, None] * ygrid[None, :] - ddf[:, None]
    ddL = (ddLpoiss + dLpoiss**2)
    hwts = np.sum(np.exp(logjoint - logli[:, None]) * ddL, axis=1) - gwts**2  # Hessian weights
    H = -X.T @ (X * hwts[:, None])

    return negL, gradient, H
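
A minimal sketch of calling this likelihood with an exponential nonlinearity (so f = exp(x) and logf, df, ddf are all cheap); gammaln and logsumexp are assumed to come from scipy.special, as used in the body:

import numpy as np
from scipy.special import gammaln, logsumexp

def exp_nlfun(x):
    f = np.exp(x)
    return f, x, f, f  # f, log f, f', f''

rng = np.random.default_rng(0)
T, D = 500, 5
X = rng.normal(size=(T, D))
w = 0.1 * rng.normal(size=D)
Y = rng.normal(size=T)  # stand-in fluorescence trace
negL, gradient, H = nll_GLM_GanmorCalciumAR1(w, X, Y, (10.0, 1.0, 0.1), exp_nlfun)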
Example #5
    def apply(self, x):
        f_n, f_c, f_d, f_h, f_w = self.filters.shape
        f = self.filters.reshape(f_n, -1)

        b = self.bias.reshape(f_n, -1)

        p = self.padding
        x_pad = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p), (p, p)),
                       mode='constant')
        x_n, x_c, x_d, x_h, x_w = x_pad.shape

        res_d = int((x_d - f_d) / self.stride) + 1
        res_h = int((x_h - f_h) / self.stride) + 1
        res_w = int((x_w - f_w) / self.stride) + 1
        size = f_c * f_d * f_h * f_w

        c_idx, d_idx, h_idx, w_idx = index3d(x_c, self.stride, (f_d, f_h, f_w),
                                             (x_d, x_h, x_w))

        res = x_pad[:, c_idx, d_idx, h_idx, w_idx]
        res = res.reshape(size, -1)

        res = f @ res + b
        res = res.reshape(1, f_n, res_d, res_h, res_w)

        return res
Example #6
def image_to_column(images, filter_shape, stride, padding):
    """Rearrange image blocks into columns.

    Parameters
    ----------

    filter_shape : tuple(height, width)
    images : np.array, shape (n_images, n_channels, height, width)
    padding: tuple(height, width)
    stride : tuple (height, width)

    """
    n_images, n_channels, height, width = images.shape
    f_height, f_width = filter_shape
    out_height, out_width = convolution_shape(height, width, (f_height, f_width), stride, padding)
    images = np.pad(images, ((0, 0), (0, 0), padding, padding), mode='constant')

    col = np.zeros((n_images, n_channels, f_height, f_width, out_height, out_width))
    for y in range(f_height):
        y_bound = y + stride[0] * out_height
        for x in range(f_width):
            x_bound = x + stride[1] * out_width
            col[:, :, y, x, :, :] = images[:, :, y:y_bound:stride[0], x:x_bound:stride[1]]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(n_images * out_height * out_width, -1)
    return col
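
A usage sketch, assuming convolution_shape implements the usual (H + 2P - F) // S + 1 output-size formula:

import numpy as np

images = np.random.rand(2, 3, 8, 8)   # 2 images, 3 channels, 8x8
col = image_to_column(images, filter_shape=(3, 3), stride=(1, 1), padding=(1, 1))
print(col.shape)                      # (2 * 8 * 8, 3 * 3 * 3) = (128, 27)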
Example #7
    def _pad_cube(self, cube):
        if self.bbox is not None:
            padded = np.pad(cube,
                            self.pad_width,
                            mode="constant",
                            constant_values=0)
            return padded[self.slices]
        return cube
Example #8
    def _pad_morph(self, morph):
        if self.bbox is not None:
            padded = np.pad(morph,
                            self.pad_width[1:],
                            mode="constant",
                            constant_values=0)
            return padded[self.slices[1:]]
        return morph
Example #9
    def _pad_sed(self, sed):
        if self.bbox is not None:
            padded = np.pad(sed,
                            self.pad_width[0],
                            mode="constant",
                            constant_values=0)
            return padded[self.slices[0]]
        else:
            return sed
Example #10
    def _convolve_band(self, model, psf_fft):
        """Convolve the model in a single band
        """
        _model = np.pad(model, self.image_padding, 'constant')
        model_fft = np.fft.fft2(np.fft.ifftshift(_model))
        convolved_fft = model_fft * psf_fft
        convolved = np.fft.ifft2(convolved_fft)
        result = np.fft.fftshift(np.real(convolved))
        # crop the padding back off; assumes every pad width is nonzero
        (bottom, top), (left, right) = self.image_padding
        result = result[bottom:-top, left:-right]
        return result
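
The method above depends on precomputed state (self.image_padding, psf_fft). The same pad -> FFT multiply -> crop pattern as a self-contained sketch:

import numpy as np

def fft_convolve2d(model, kernel, pad=((8, 8), (8, 8))):
    # pad both inputs to suppress wrap-around from the circular FFT convolution
    model_fft = np.fft.fft2(np.fft.ifftshift(np.pad(model, pad, 'constant')))
    kernel_fft = np.fft.fft2(np.fft.ifftshift(np.pad(kernel, pad, 'constant')))
    convolved = np.fft.fftshift(np.real(np.fft.ifft2(model_fft * kernel_fft)))
    (bottom, top), (left, right) = pad
    return convolved[bottom:-top, left:-right]  # assumes every pad width is > 0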
Example #11
def nll_GLM_GanmorCalciumAR1(w, X, Y, hyperparams, nlfun, S=10):
    """
    Negative log-likelihood for a GLM with Ganmor AR(p) mixture model for calcium imaging data.

    Input:
        w:              [D x 1]  vector of GLM regression weights
        X:              [T x D]  design matrix
        Y:              [T x 1]  calcium fluorescence observations
        hyperparams:    model hyperparameters: AR(p) coefficients, log alpha, log Gaussian variance
        nlfun:          [func]   function handle for nonlinearity, returning (f, log f, f', f'')
        S:              [scalar] number of spikes to marginalize over

    Output:
        negative log-likelihood
    """
    # unpack hyperparams
    ar_coefs, log_alpha, log_sig2 = hyperparams
    alpha = np.exp(log_alpha)
    sig2 = np.exp(log_sig2)
    p = ar_coefs.shape[0]  # AR(p)
    # compute AR(p) diffs
    Ydecay = np.zeros_like(Y)
    Y = np.pad(Y, (p, 0))  # pad Y by p time bins
    for i, ai in enumerate(ar_coefs):
        Ydecay = Ydecay + ai * Y[p - 1 - i:-1 - i]
    Ydff = (Y[p:] - Ydecay) / alpha

    # compute grid of spike counts
    ygrid = np.arange(0, S + 1)

    # Gaussian log-likelihood terms
    log_gauss_grid = -0.5 * (Ydff[:, None] - ygrid[None, :])**2 / (
        sig2 / alpha**2) - 0.5 * np.log(2.0 * np.pi * sig2)

    Xproj = X @ w
    poissConst = gammaln(ygrid + 1)

    # compute neglogli, gradient, and (optionally) Hessian
    f, logf, df, ddf = nlfun(Xproj)
    logPcounts = logf[:, None] * ygrid[None, :] - f[:, None] - poissConst[None, :]

    # compute log-likelihood for each time bin
    logjoint = log_gauss_grid + logPcounts
    logli = logsumexp(logjoint, axis=1)  # log likelihood for each time bin
    negL = -np.sum(logli)  # negative log likelihood

    return negL
Example #12
    def match(self, scene):

        # 1) determine shape of scene in obs, set mask

        # 2) compute the interpolation kernel between scene and obs

        # 3) compute obs.psf in the frame of scene, store in Fourier space
        # A few notes on this procedure:
        # a) This assumes that scene.psfs and self.psfs have the same spatial shape,
        #    which will need to be modified for multi-resolution datasets
        if self._psfs is not None:
            ipad, ppad = interpolation.get_common_padding(self.images,
                                                          self._psfs,
                                                          padding=self.padding)
            self.image_padding, self.psf_padding = ipad, ppad
            _psfs = np.pad(self._psfs, ((0, 0), *self.psf_padding), 'constant')
            _target = np.pad(scene._psfs, self.psf_padding, 'constant')

            new_kernel_fft = []
            # Deconvolve the target PSF
            target_fft = np.fft.fft2(np.fft.ifftshift(_target))

            for _psf in _psfs:
                observed_fft = np.fft.fft2(np.fft.ifftshift(_psf))
                # Create the matching kernel
                kernel_fft = observed_fft / target_fft
                # Take the inverse Fourier transform to normalize the result
                # Trials without this operation are slow to converge, but in the future
                # we may be able to come up with a method to normalize in the Fourier Transform
                # and avoid this step.
                kernel = np.fft.ifft2(kernel_fft)
                kernel = np.fft.fftshift(np.real(kernel))
                kernel /= kernel.sum()
                # Store the Fourier transform of the matching kernel
                new_kernel_fft.append(np.fft.fft2(np.fft.ifftshift(kernel)))
            self.psfs_fft = np.array(new_kernel_fft)

        return self
Example #13
def line_bending_term(x, uniform_grid):
    """
    :param x: flattened grid
    :param uniform_grid: simple grid
    :return: line bending term
    """
    x = x.reshape(uniform_grid.shape)

    down = np.pad(np.arange(1, x.shape[0]), (0, 1), mode='edge')
    up = np.pad(np.arange(0, x.shape[0] - 1), (1, 0), mode='edge')
    right = np.pad(np.arange(1, x.shape[1]), (0, 1), mode='edge')
    left = np.pad(np.arange(0, x.shape[1] - 1), (1, 0), mode='edge')

    def bend(neighbor_x, neighbor_u):
        d = x - neighbor_x
        t = normalized(uniform_grid - neighbor_u, axis=2)
        # squared cross product of the displacement with the original edge direction
        return np.sum((d[:, :, 0] * t[:, :, 1] - d[:, :, 1] * t[:, :, 0]) ** 2)

    return (bend(x[down, :], uniform_grid[down, :]) +
            bend(x[up, :], uniform_grid[up, :]) +
            bend(x[:, right], uniform_grid[:, right]) +
            bend(x[:, left], uniform_grid[:, left]))
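
The normalized helper is not shown; a minimal sketch consistent with its use here (unit direction vectors along the given axis, with a hypothetical zero-norm guard):

import numpy as np

def normalized(v, axis):
    norm = np.linalg.norm(v, axis=axis, keepdims=True)
    return v / np.where(norm == 0, 1.0, norm)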
Example #14
    def get_anchor_points(self, actions):
        """ Builds anchor action point sets for the direct estimator

        Args:
            actions (np.array): actions drawn from the logging policy to build quantiles on
        """
        if self.mode == 'quantile':
            self.quantiles = np.quantile(actions, np.linspace(0, 1, self.K+1))
            self.action_set = np.pad(self.quantiles, 1, 'constant', constant_values=(self.eps, np.inf))
        elif self.mode == 'grid':
            self.action_set = np.arange(self.eps, np.max(actions) + self.stride, self.stride)
            self.K = int(np.max(actions)/self.stride)
        self.initialized = True
        return self.action_set
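
A standalone sketch of the 'quantile' branch: K + 1 quantile edges, padded on each side so the first and last bins are effectively open-ended (K and eps here stand in for the attributes used above):

import numpy as np

K, eps = 5, 1e-7
actions = np.random.lognormal(mean=0.0, sigma=1.0, size=1000)
quantiles = np.quantile(actions, np.linspace(0, 1, K + 1))
action_set = np.pad(quantiles, 1, 'constant', constant_values=(eps, np.inf))
print(action_set.shape)  # (K + 3,)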
Example #15
def get_submatrix_add(lst_matrix, center_pt_tuple, convolution):
    np_matrix = np.array(lst_matrix)
    r, c = np_matrix.shape
    lt_padding = center_pt_tuple[0] - 1
    rt_padding = (c - 1) - (center_pt_tuple[0] + 1)
    top_padding = center_pt_tuple[1] - 1
    btm_padding = (r - 1) - (center_pt_tuple[1] + 1)
    row_start = 0
    row_end = np.array(convolution).shape[0]
    col_start = 0
    col_end = np.array(convolution).shape[1]

    if lt_padding < 0:
        lt_padding = 0
        row_start = row_start + 1
    if rt_padding < 0:
        rt_padding = 0
        row_end = row_end - 1
    if top_padding < 0:
        top_padding = 0
        col_start = col_start + 1
    if btm_padding < 0:
        btm_padding = 0
        col_end = col_end - 1

    padded_convo = np.pad(np.array(convolution)[row_start:row_end,
                                                col_start:col_end],
                          ((top_padding, btm_padding),
                           (lt_padding, rt_padding)),
                          mode='constant',
                          constant_values=(0, 0))

    try:
        new_matrix = np_matrix + padded_convo
    except Exception as e:
        print(e)
        print('left pad: ', lt_padding)
        print('right pad: ', rt_padding)
        print('top pad: ', top_padding)
        print('btm pad: ', btm_padding)
        print('r start:', row_start)
        print('r end:', row_end)
        print('c start:', col_start)
        print('c end:', col_end)
        print(np.array(convolution)[row_start:row_end, col_start:col_end])
        print(
            np.array(convolution)[row_start:row_end, col_start:col_end].shape)
        raise  # re-raise so the caller sees the original error instead of a NameError below

    return new_matrix
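
A usage sketch with an interior center, where no clipping is triggered (the row/column roles of center_pt_tuple follow the function's own convention):

import numpy as np

grid = [[0] * 5 for _ in range(5)]
stencil = [[1, 1, 1], [1, 2, 1], [1, 1, 1]]
print(get_submatrix_add(grid, (2, 2), stencil))  # stencil added, centered at (2, 2)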
Example #16
def image_to_column(images, filter_shape, stride, padding):
    n_images, n_channels, height, width = images.shape
    f_height, f_width = filter_shape
    out_height, out_width = convolution_shape(height, width, (f_height, f_width), stride, padding)
    images = np.pad(images, ((0, 0), (0, 0), padding, padding), mode="constant")

    col = np.zeros((n_images, n_channels, f_height, f_width, out_height, out_width))
    for y in range(f_height):
        y_bound = y + stride[0] * out_height
        for x in range(f_width):
            x_bound = x + stride[1] * out_width
            col[:, :, y, x, :, :] = images[:, :, y : y_bound : stride[0], x : x_bound : stride[1]]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(n_images * out_height * out_width, -1)
    return col
Example #17
def calculate_forces(x_phys, args):
    applied_force = args['forces']
    
    if not args.get('g'):
        return applied_force
    
    density = 0
    for pad_left in [0, 1]:
        for pad_up in [0, 1]:
            padding = [(pad_left, 1 - pad_left), (pad_up, 1 - pad_up)]
            density += (1 / 4) * np.pad(
                x_phys.T, padding, mode='constant', constant_values=0
            )
    gravitational_force = -args['g'] * density[..., np.newaxis] * np.array([0, 1])
    return applied_force + gravitational_force.ravel()
Example #18
    def _setup(self):
        """ Setup the experiments and creates the data
        """
        # Actions
        features, y = self.get_X_y_by_name()
        potentials = self._get_potentials(y)
        actions = self.rng.lognormal(mean=self.start_mu,
                                     sigma=self.start_sigma,
                                     size=potentials.shape[0])
        rewards = self.get_rewards_from_actions(potentials, actions)
        if self.discrete:
            from scipy.stats import lognorm
            rv = lognorm(s=self.start_sigma, scale=np.exp(self.start_mu))
            quantiles = np.quantile(actions,
                                    np.linspace(0, 1, self.discrete + 1))
            action_anchors = np.pad(quantiles,
                                    1,
                                    'constant',
                                    constant_values=(1e-7, np.inf))
            bins = action_anchors[:-1]
            inds = np.digitize(actions, bins, right=True)
            inds_1 = inds - 1
            inds_1[inds_1 == -1] = 0
            pi_logging = rv.cdf(bins[inds]) - rv.cdf(bins[inds_1])
        else:
            pi_logging = Dataset.logging_policy(actions, self.start_mu,
                                                self.start_sigma)

        # Test train split
        self.actions_train, self.actions_test, self.features_train, self.features_test, self.reward_train, \
        self.reward_test, self.pi_0_train, self.pi_0_test, self.potentials_train, self.potentials_test, \
        self.l_train, self.l_test = train_test_split(actions, features, rewards, pi_logging, potentials, y,
                                                     train_size=self.train_size, random_state=42)

        self.actions_train, self.actions_valid, self.features_train, self.features_valid, self.reward_train, \
        self.reward_valid, self.pi_0_train, self.pi_0_valid, self.potentials_train, self.potentials_valid, \
        self.l_train, self.l_valid = train_test_split(self.actions_train, self.features_train, self.reward_train,
                                                      self.pi_0_train, self.potentials_train, self.l_train,
                                                     train_size=self.val_size, random_state=42)

        min_max_scaler = MinMaxScaler(feature_range=(0, 1))
        self.features_train = min_max_scaler.fit_transform(self.features_train)
        self.features_valid = min_max_scaler.transform(self.features_valid)
        self.features_test = min_max_scaler.transform(self.features_test)

        self.baseline_reward_valid = np.mean(self.reward_valid)
        self.baseline_reward_test = np.mean(self.reward_test)
Example #19
    def apply(self, x):
        k_l = self.kernel

        p = self.padding
        x_pad = np.pad(x, ((0, 0), (0, 0), (p, p)), mode='constant')
        x_n, x_c, x_l = x_pad.shape

        res_l = int((x_l - k_l) / self.stride) + 1

        c_idx, l_idx = index1d(x_c, self.stride, self.kernel, (x_l))

        res = x_pad[:, c_idx, l_idx]
        res = res.reshape(x_c, k_l, -1)

        res = np.max(res, axis=1)
        res = res.reshape(1, x_c, res_l)

        return res
Example #20
    def print_images(self,
                     graph_name='Vasc_Graph.png',
                     img_name='Vasc2D_img.png'):
        fig = plt.figure()
        for j, s in enumerate(self.tri.simplices):
            p = np.array(self.pts)[s].mean(axis=0)
            plt.text(p[0], p[1], 'Cell #%d' % j,
                     ha='center')  # label triangles
        plt.triplot(
            np.array(self.pts)[:, 0],
            np.array(self.pts)[:, 1], self.tri.simplices)
        plt.plot(np.array(self.pts)[:, 0], np.array(self.pts)[:, 1], 'o')
        fig.savefig(graph_name)
        plt.close(fig)

        # https://stackoverflow.com/questions/38191855/zero-pad-numpy-array
        img = np.pad(np.array(self.img), ((2, 3), (2, 3)), 'constant')
        plt.imsave(img_name, np.rot90(img), cmap='jet')
Example #21
def pad(var,
        pad_len,
        mode='constant',
        constant_values=0,
        override_backend=None):
    """
    :param pad_len: A tuple of tuples. Consistent with the format of numpy.pad.
    :param mode: Choose from 'constant', 'edge', 'reflect', 'wrap'.
    """
    bn = override_backend if override_backend is not None else global_settings.backend
    args = {}
    mode_dict = {
        'constant': {
            'autograd': 'constant',
            'pytorch': 'constant'
        },
        'edge': {
            'autograd': 'edge',
            'pytorch': 'replicate'
        },
        'reflect': {
            'autograd': 'reflect',
            'pytorch': 'reflect'
        },
        'wrap': {
            'autograd': 'wrap',
            'pytorch': 'circular'
        }
    }
    if mode == 'constant':
        args['constant_values'] = constant_values
    if bn == 'autograd':
        return anp.pad(var, pad_len, mode=mode_dict[mode][bn], **args)
    elif bn == 'pytorch':
        pad_len = [x for y in pad_len[::-1] for x in y]
        return tc.nn.functional.pad(var,
                                    pad_len,
                                    mode=mode_dict[mode][bn],
                                    value=constant_values)
    elif bn == 'numpy':
        return np.pad(var, pad_len, mode=mode, **args)
Example #22
    def apply(self, x):
        k_h, k_w = self.kernel

        p = self.padding
        x_pad = np.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
        x_n, x_c, x_h, x_w = x_pad.shape

        res_h = int((x_h - k_h) / self.stride) + 1
        res_w = int((x_w - k_w) / self.stride) + 1

        c_idx, h_idx, w_idx = index2d(x_c, self.stride, self.kernel,
                                      (x_h, x_w))

        res = x_pad[:, c_idx, h_idx, w_idx]

        res = res.reshape(x_c, k_h * k_w, -1)

        res = np.max(res, axis=1)
        res = res.reshape(1, x_c, res_h, res_w)

        return res
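
The index1d/index2d/index3d helpers are not shown. A minimal index2d sketch consistent with how apply gathers patches (batch size 1), modeled on the standard im2col index trick:

import numpy as np

def index2d(channels, stride, kernel, size):
    # hypothetical reconstruction: fancy-index arrays that pull every
    # kernel-sized patch out of a padded (N, C, H, W) input in one gather
    k_h, k_w = kernel
    h, w = size
    out_h = (h - k_h) // stride + 1
    out_w = (w - k_w) // stride + 1
    # one row per (channel, kernel offset); one column per output location
    c_idx = np.repeat(np.arange(channels), k_h * k_w).reshape(-1, 1)
    i0 = np.tile(np.repeat(np.arange(k_h), k_w), channels).reshape(-1, 1)
    i1 = stride * np.repeat(np.arange(out_h), out_w).reshape(1, -1)
    j0 = np.tile(np.arange(k_w), k_h * channels).reshape(-1, 1)
    j1 = stride * np.tile(np.arange(out_w), out_h).reshape(1, -1)
    return c_idx, i0 + i1, j0 + j1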
Example #23
def admixture_operator(n_node, p):
    # axis0=n_from_parent, axis1=der_from_parent, axis2=der_in_parent
    der_in_parent = np.tile(np.arange(n_node + 1), (n_node + 1, n_node + 1, 1))
    n_from_parent = np.transpose(der_in_parent, [2, 0, 1])
    der_from_parent = np.transpose(der_in_parent, [0, 2, 1])

    anc_in_parent = n_node - der_in_parent
    anc_from_parent = n_from_parent - der_from_parent

    x = comb(der_in_parent, der_from_parent) * comb(
        anc_in_parent, anc_from_parent) / comb(n_node, n_from_parent)
    # rearrange so axis0=1, axis1=der_in_parent, axis2=der_from_parent, axis3=n_from_parent
    x = np.transpose(x)
    x = np.reshape(x, [1] + list(x.shape))

    n = np.arange(n_node + 1)
    B = comb(n_node, n)

    # the two arrays to convolve_sum_axes
    x1 = (x * B * ((1 - p)**n) * (p**(n[::-1])))
    x2 = x[:, :, :, ::-1]

    # reduce array size; approximate low probability events with 0
    mu = n_node * (1 - p)
    sigma = np.sqrt(n_node * p * (1 - p))
    n_sd = 4
    lower = np.max((0, np.floor(mu - n_sd * sigma)))
    upper = np.min((n_node, np.ceil(mu + n_sd * sigma))) + 1
    lower, upper = int(lower), int(upper)

    # note: slicing x1/x2 down to [lower:upper] is currently disabled, so lower/upper are unused

    ret = convolve_sum_axes(x1, x2)
    # axis0=der_in_parent1, axis1=der_in_parent2, axis2=der_in_child
    ret = np.reshape(ret, ret.shape[1:])
    if ret.shape[2] < (n_node + 1):
        ret = np.pad(ret, [(0, 0), (0, 0), (0, n_node + 1 - ret.shape[2])],
                     "constant")
    return ret[:, :, :(n_node + 1)]
Example #24
def pad(var, pad_len, mode='constant', constant_values=0, backend='autograd'):
    """
    Pad array.
    [ATTENTION: The behavior of this function is different between Autograd and Pytorch backend.]

    :param pad_len: A tuple of tuples. Consistent with the format of numpy.pad.
    :param mode: Choose from 'constant', 'edge', 'reflect', 'wrap'.
    """
    args = {}
    mode_dict = {'constant': {'autograd': 'constant', 'pytorch': 'constant'},
                 'edge':    {'autograd': 'edge',    'pytorch': 'replicate'},
                 'reflect': {'autograd': 'reflect', 'pytorch': 'reflect'},
                 'wrap':    {'autograd': 'wrap',    'pytorch': 'circular'}}
    if mode == 'constant':
        args['constant_values'] = constant_values
    if backend == 'autograd':
        return anp.pad(var, pad_len, mode=mode_dict[mode][backend], **args)
    elif backend == 'pytorch':
        pad_len = [x for y in pad_len[::-1] for x in y]
        return tc.nn.functional.pad(var, pad_len, mode=mode_dict[mode][backend], value=constant_values)
    elif backend == 'numpy':
        return np.pad(var, pad_len, mode=mode, **args)
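
A usage sketch for the numpy backend (the autograd/pytorch branches additionally assume anp and tc imports, as in the snippet):

import numpy as np

x = np.ones((2, 3))
y = pad(x, ((1, 1), (2, 2)), mode='constant', backend='numpy')
print(y.shape)  # (4, 7)
# for the pytorch branch, ((1, 1), (2, 2)) is flattened last-axis-first to
# [2, 2, 1, 1], matching torch.nn.functional.pad's convention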
Example #25
    def apply(self, x):
        f_n, f_c, f_l = self.filters.shape  # 2, 3, 4
        f = self.filters.reshape(f_n, -1)  # 2, 12

        b = self.bias.reshape(f_n, -1)  # 2, 1

        p = self.padding
        x_pad = np.pad(x, ((0, 0), (0, 0), (p, p)), mode='constant')
        x_n, x_c, x_l = x_pad.shape  # 1, 3, 10

        res_l = int((x_l - f_l) / self.stride) + 1
        size = f_c * f_l

        c_idx, l_idx = index1d(x_c, self.stride, (f_l), (x_l))

        res = x_pad[:, c_idx, l_idx]  # 1, 12, 10
        res = res.reshape(size, -1)  # 12, 10

        res = f @ res + b  # 2, 10
        res = res.reshape(1, f_n, res_l)  # 1, 2, 10

        return res
Example #26
def multislice_propagate_cnn(grid_delta,
                             grid_beta,
                             probe_real,
                             probe_imag,
                             energy_ev,
                             psize_cm,
                             kernel_size=17,
                             free_prop_cm=None,
                             debug=False):

    assert kernel_size % 2 == 1, 'kernel_size must be an odd number.'
    n_batch, shape_y, shape_x, n_slice = grid_delta.shape
    lmbda_nm = 1240. / energy_ev
    voxel_nm = np.array(psize_cm) * 1.e7
    delta_nm = voxel_nm[-1]
    k = 2. * np.pi * delta_nm / lmbda_nm
    grid_shape = np.array(grid_delta.shape[1:])
    size_nm = voxel_nm * grid_shape
    mean_voxel_nm = np.prod(voxel_nm)**(1. / 3)

    # print('Critical distance is {} cm.'.format(psize_cm[0] * psize_cm[1] * grid_delta.shape[1] / (lmbda_nm * 1e-7)))

    # kernel = get_kernel(delta_nm, lmbda_nm, voxel_nm, np.array(grid_delta.shape[1:]))
    kernel = get_kernel(delta_nm, lmbda_nm, voxel_nm, grid_shape - 1)
    kernel = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(kernel)))

    kernel_mid = ((np.array(kernel.shape) - 1) / 2).astype('int')
    half_kernel_size = int((kernel_size - 1) / 2)
    kernel = kernel[kernel_mid[0] - half_kernel_size:kernel_mid[0] + half_kernel_size + 1,
                    kernel_mid[1] - half_kernel_size:kernel_mid[1] + half_kernel_size + 1]
    # kernel = get_kernel_ir_real(delta_nm, lmbda_nm, voxel_nm, [kernel_size, kernel_size, 256])
    # kernel /= kernel.size
    pad_len = (kernel_size - 1) // 2

    # probe_real = np.pad(probe_real, [[pad_len, pad_len], [pad_len, pad_len]], mode='constant', constant_values=1.0)
    # probe_imag = np.pad(probe_real, [[pad_len, pad_len], [pad_len, pad_len]], mode='constant', constant_values=0)
    probe = probe_real + 1j * probe_imag
    probe_size = probe.shape
    probe = np.tile(probe, [n_batch, 1, 1])

    # grid_delta = np.pad(grid_delta, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]], mode='constant', constant_values=0)
    # grid_beta = np.pad(grid_beta, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]], mode='constant', constant_values=0)

    probe_array = []

    t0 = time.time()

    edge_val = 1.0

    initial_int = probe[0, 0, 0]
    for i_slice in trange(n_slice):
        this_delta_batch = grid_delta[:, :, :, i_slice]
        this_beta_batch = grid_beta[:, :, :, i_slice]
        # this_delta_batch = np.squeeze(this_delta_batch)
        # this_beta_batch = np.squeeze(this_beta_batch)
        c = np.exp(1j * k * this_delta_batch - k * this_beta_batch)
        probe = probe * c
        # print(probe.shape, kernel.shape)
        # probe = scipy.signal.convolve2d(np.squeeze(probe), kernel, mode='same', boundary='wrap', fillvalue=1)
        # probe = np.reshape(probe, [1, probe.shape[0], probe.shape[1]])

        probe = np.pad(probe, [[0, 0], [pad_len, pad_len], [pad_len, pad_len]],
                       mode='constant',
                       constant_values=edge_val)
        # probe = np.pad(probe, [[0, 0], [pad_len, pad_len], [pad_len, pad_len]], mode='wrap')
        probe = convolve(probe, kernel, mode='valid', axes=([1, 2], [0, 1]))

        # probe = np.reshape(probe, [n_batch, np.prod(probe_size)])
        # probe = probe.dot(kernel_mat.T)
        # probe = np.reshape(probe, [n_batch, *probe_size])

        # the constant padding value must track how the kernel rescales edge intensity
        edge_val = np.sum(kernel) * edge_val
        # probe = ifft2(np_ifftshift(np_fftshift(fft2(probe)) * np_fftshift(fft2(kernel))))
        # probe = ifft2(np_ifftshift(np_fftshift(fft2(probe)) * kernel))

        # re-normalize to 1
        # probe *= 1. / np.mean(np.abs(probe))

        probe_array.append(np.abs(probe))

    final_int = probe[0, 0, 0]
    probe *= (initial_int / final_int)

    if free_prop_cm is not None:
        if free_prop_cm == 'inf':
            probe = np.fft.fftshift(np.fft.fft2(probe), axes=[1, 2])
        else:
            dist_nm = free_prop_cm * 1e7
            l = np.prod(size_nm)**(1. / 3)
            crit_samp = lmbda_nm * dist_nm / l
            algorithm = 'TF' if mean_voxel_nm > crit_samp else 'IR'
            algorithm = 'TF'  # note: the TF kernel is currently forced regardless of the criterion
            if algorithm == 'TF':
                h = get_kernel(dist_nm, lmbda_nm, voxel_nm, grid_shape)
                probe = np.fft.ifft2(
                    np.fft.ifftshift(
                        np.fft.fftshift(np.fft.fft2(probe), axes=[1, 2]) * h,
                        axes=[1, 2]))
            else:
                h = get_kernel_ir(dist_nm, lmbda_nm, voxel_nm, grid_shape)
                probe = np.fft.ifft2(
                    np.fft.ifftshift(
                        np.fft.fftshift(np.fft.fft2(probe), axes=[1, 2]) * h,
                        axes=[1, 2]))

    if debug:
        return probe, probe_array, time.time() - t0
    else:
        return probe
Example #27
def optimTrajectory(path, distObs, grid_obs, trajDuration):
    path = np.pad(path, ((6, 6), (0, 0)), mode='edge')
    optim = np.copy(path)
    delta_t = trajDuration / len(path)
    # For each iteration, we optimize between [startOptim, startOptim + OPTIMIZED_POINTS]
    startOptim = INITIAL_OPTIM_OFFSET

    # Optimization configuration
    def objFun(optim, globalPath, startOptim, distObs, delta_t):
        def E(x):
            inp = np.vstack(
                (optim[:startOptim], x.reshape(OPTIMIZED_POINTS, 2),
                 optim[startOptim + OPTIMIZED_POINTS:]))
            return cost(inp, globalPath, startOptim - (6 - OPTIMIZED_POINTS),
                        distObs, delta_t)

        gradE = autograd.grad(E)
        return E, gradE

    epochs = 20

    losses = np.zeros((len(path) - 6 - startOptim - OPTIMIZED_POINTS, epochs))
    while startOptim + OPTIMIZED_POINTS <= len(path) - 6:
        f, df = objFun(optim, path, startOptim, distObs, delta_t)
        x0 = optim[startOptim:startOptim + OPTIMIZED_POINTS].reshape(-1)

        def cb(xk):
            optim[startOptim:startOptim + OPTIMIZED_POINTS, :] = xk.reshape(
                -1, 2)
            g = df(xk)
            print('grad norm', np.linalg.norm(g))
            plot.display(None,
                         None,
                         grid_obs,
                         path,
                         optim,
                         delta_t=delta_t,
                         currentOptimIdx=startOptim,
                         grad=g.reshape(-1, 2),
                         hold=.01)

        result = scipy.optimize.minimize(f,
                                         x0,
                                         method='BFGS',
                                         jac=df,
                                         callback=cb,
                                         options={'gtol': 1e-4})

        print('*' * 30)
        print('Optimization result ({} steps):'.format(result.nit),
              result.success, result.message)
        print('*' * 30)
        optim[startOptim:startOptim + OPTIMIZED_POINTS, :] = result.x.reshape(
            -1, 2)
        plot.display(None,
                     None,
                     grid_obs,
                     path,
                     optim,
                     delta_t=delta_t,
                     currentOptimIdx=startOptim,
                     hold=.1)
        startOptim += 1
    return optim
Example #28
def reconstruct_fullfield(fname,
                          theta_st=0,
                          theta_end=PI,
                          n_epochs='auto',
                          crit_conv_rate=0.03,
                          max_nepochs=200,
                          alpha=1e-7,
                          alpha_d=None,
                          alpha_b=None,
                          gamma=1e-6,
                          learning_rate=1.0,
                          output_folder=None,
                          minibatch_size=None,
                          save_intermediate=False,
                          full_intermediate=False,
                          energy_ev=5000,
                          psize_cm=1e-7,
                          n_epochs_mask_release=None,
                          cpu_only=False,
                          save_path='.',
                          shrink_cycle=10,
                          core_parallelization=True,
                          free_prop_cm=None,
                          multiscale_level=1,
                          n_epoch_final_pass=None,
                          initial_guess=None,
                          n_batch_per_update=5,
                          dynamic_rate=True,
                          probe_type='plane',
                          probe_initial=None,
                          probe_learning_rate=1e-3,
                          pupil_function=None,
                          theta_downsample=None,
                          forward_algorithm='fresnel',
                          random_theta=True,
                          object_type='normal',
                          fresnel_approx=True,
                          shared_file_object=False,
                          reweighted_l1=False,
                          **kwargs):
    """
    Reconstruct a beyond depth-of-focus object.
    :param fname: Filename and path of raw data file. Must be in HDF5 format.
    :param theta_st: Starting rotation angle.
    :param theta_end: Ending rotation angle.
    :param n_epochs: Number of epochs to be executed. If given 'auto', optimizer will stop
                     when reduction rate of loss function goes below crit_conv_rate.
    :param crit_conv_rate: Reduction rate of loss function below which the optimizer should
                           stop.
    :param max_nepochs: The maximum number of epochs to be executed if n_epochs is 'auto'.
    :param alpha: Weighting coefficient for both delta and beta regularizer. Should be None
                  if alpha_d and alpha_b are specified.
    :param alpha_d: Weighting coefficient for delta regularizer.
    :param alpha_b: Weighting coefficient for beta regularizer.
    :param gamma: Weighting coefficient for TV regularizer.
    :param learning_rate: Learning rate of ADAM.
    :param output_folder: Name of output folder. Put None for auto-generated pattern.
    :param downsample: Downsampling (not implemented yet).
    :param minibatch_size: Size of minibatch.
    :param save_intermediate: Whether to save the object after each epoch.
    :param energy_ev: Beam energy in eV.
    :param psize_cm: Pixel size in cm.
    :param n_epochs_mask_release: The number of epochs after which the finite support mask
                                  is released. Put None to disable this feature.
    :param cpu_only: Whether to disable GPU.
    :param save_path: The location of finite support mask, the prefix of output_folder and
                      other metadata.
    :param shrink_cycle: Shrink-wrap is executed per every this number of epochs.
    :param core_parallelization: Whether to use Horovod for parallelized computation within
                                 this function.
    :param free_prop_cm: The distance to propagate the wavefront in free space after exiting
                         the sample, in cm.
    :param multiscale_level: The level of multiscale processing. When this number is m and
                             m > 1, m - 1 low-resolution reconstructions will be performed
                             before reconstructing with the original resolution. The downsampling
                             factor for these coarse reconstructions will be [2^(m - 1),
                             2^(m - 2), ..., 2^1].
    :param n_epoch_final_pass: specify a number of iterations for the final pass if multiscale
                               is activated. If None, it will be the same as n_epoch.
    :param initial_guess: supply an initial guess. If None, object will be initialized with noise.
    :param n_batch_per_update: number of minibatches during which gradients are accumulated, after
                               which obj is updated.
    :param dynamic_rate: when n_batch_per_update > 1, adjust learning rate dynamically to allow it
                         to decrease with epoch number
    :param probe_type: type of wavefront. Can be 'plane', 'fixed', 'optimizable', 'point', or
                           'gaussian'. If 'optimizable', the probe function will be optimized
                           along with the object.
    :param probe_initial: can be provided for 'optimizable' probe_type, and must be provided for
                              'fixed'.
    """
    def calculate_loss(obj_delta, obj_beta, this_ind_batch, this_prj_batch):

        if not shared_file_object:
            obj_stack = np.stack([obj_delta, obj_beta], axis=3)
            obj_rot_batch = []
            for i in range(minibatch_size):
                obj_rot_batch.append(
                    apply_rotation(
                        obj_stack, coord_ls[this_ind_batch[i]],
                        'arrsize_{}_{}_{}_ntheta_{}'.format(
                            dim_y, dim_x, dim_x, n_theta)))
            obj_rot_batch = np.stack(obj_rot_batch)

            exiting_batch = multislice_propagate_batch_numpy(
                obj_rot_batch[:, :, :, :, 0],
                obj_rot_batch[:, :, :, :, 1],
                probe_real,
                probe_imag,
                energy_ev,
                psize_cm * ds_level,
                free_prop_cm=free_prop_cm,
                obj_batch_shape=[minibatch_size, *this_obj_size],
                kernel=h,
                fresnel_approx=fresnel_approx)
            loss = np.mean((np.abs(exiting_batch) - np.abs(this_prj_batch))**2)

        else:
            exiting_batch = multislice_propagate_batch_numpy(
                obj_delta,
                obj_beta,
                probe_real,
                probe_imag,
                energy_ev,
                psize_cm * ds_level,
                free_prop_cm=free_prop_cm,
                obj_batch_shape=obj_delta.shape,
                kernel=h,
                fresnel_approx=fresnel_approx)
            exiting_batch = exiting_batch[
                :, safe_zone_width:exiting_batch.shape[1] - safe_zone_width,
                safe_zone_width:exiting_batch.shape[2] - safe_zone_width]
            loss = np.mean((np.abs(exiting_batch) - np.abs(this_prj_batch))**2)

        dxchange.write_tiff(
            np.squeeze(abs(exiting_batch._value)),
            'cone_256_foam/test_shared_file_object/current/exit_{}'.format(
                rank),
            dtype='float32',
            overwrite=True)
        dxchange.write_tiff(
            np.squeeze(abs(this_prj_batch)),
            'cone_256_foam/test_shared_file_object/current/prj_{}'.format(
                rank),
            dtype='float32',
            overwrite=True)

        reg_term = 0
        if reweighted_l1:
            if alpha_d not in [None, 0]:
                reg_term = reg_term + alpha_d * np.mean(
                    weight_l1 * np.abs(obj_delta))
                loss = loss + reg_term
            if alpha_b not in [None, 0]:
                reg_term = reg_term + alpha_b * np.mean(
                    weight_l1 * np.abs(obj_beta))
                loss = loss + reg_term
        else:
            if alpha_d not in [None, 0]:
                reg_term = reg_term + alpha_d * np.mean(np.abs(obj_delta))
                loss = loss + reg_term
            if alpha_b not in [None, 0]:
                reg_term = reg_term + alpha_b * np.mean(np.abs(obj_beta))
                loss = loss + reg_term
        if gamma not in [None, 0]:
            if shared_file_object:
                reg_term = reg_term + gamma * total_variation_3d(obj_delta,
                                                                 axis_offset=1)
            else:
                reg_term = reg_term + gamma * total_variation_3d(obj_delta,
                                                                 axis_offset=0)
            loss = loss + reg_term

        print('Loss:', loss._value, 'Regularization term:',
              reg_term._value if reg_term != 0 else 0)

        # Write convergence data
        f_conv.write('{},{},{},'.format(i_epoch, i_batch, loss._value))
        f_conv.flush()

        return loss

    comm = MPI.COMM_WORLD
    n_ranks = comm.Get_size()
    rank = comm.Get_rank()
    t_zero = time.time()

    # read data
    t0 = time.time()
    print_flush('Reading data...', 0, rank)
    f = h5py.File(os.path.join(save_path, fname), 'r')
    prj_0 = f['exchange/data']
    theta = -np.linspace(theta_st, theta_end, prj_0.shape[0], dtype='float32')
    n_theta = len(theta)
    prj_theta_ind = np.arange(n_theta, dtype=int)
    if theta_downsample is not None:
        prj_0 = prj_0[::theta_downsample]
        theta = theta[::theta_downsample]
        prj_theta_ind = prj_theta_ind[::theta_downsample]
        n_theta = len(theta)
    original_shape = prj_0.shape
    comm.Barrier()
    print_flush('Data reading: {} s'.format(time.time() - t0), 0, rank)
    print_flush('Data shape: {}'.format(original_shape), 0, rank)
    comm.Barrier()

    if output_folder is None:
        output_folder = 'recon_360_minibatch_{}_' \
                        'mskrls_{}_' \
                        'shrink_{}_' \
                        'iter_{}_' \
                        'alphad_{}_' \
                        'alphab_{}_' \
                        'gamma_{}_' \
                        'rate_{}_' \
                        'energy_{}_' \
                        'size_{}_' \
                        'ntheta_{}_' \
                        'prop_{}_' \
                        'ms_{}_' \
                        'cpu_{}' \
            .format(minibatch_size, n_epochs_mask_release, shrink_cycle,
                    n_epochs, alpha_d, alpha_b,
                    gamma, learning_rate, energy_ev,
                    prj_0.shape[-1], prj_0.shape[0], free_prop_cm,
                    multiscale_level, cpu_only)
        if abs(PI - theta_end) < 1e-3:
            output_folder += '_180'

    if save_path != '.':
        output_folder = os.path.join(save_path, output_folder)

    for ds_level in range(multiscale_level - 1, -1, -1):

        # only the first (coarsest) multiscale pass starts from scratch
        initializer_flag = ds_level != range(multiscale_level - 1, -1, -1)[0]

        ds_level = 2**ds_level
        print_flush('Multiscale downsampling level: {}'.format(ds_level), 0,
                    rank)
        comm.Barrier()

        # Physical metadata
        voxel_nm = np.array([psize_cm] * 3) * 1.e7 * ds_level
        lmbda_nm = 1240. / energy_ev
        delta_nm = voxel_nm[-1]

        # downsample data
        prj = prj_0
        # prj = np.copy(prj_0)
        # if ds_level > 1:
        #     prj = prj[:, ::ds_level, ::ds_level]
        #     prj = prj.astype('complex64')
        # comm.Barrier()

        dim_y, dim_x = prj.shape[-2] // ds_level, prj.shape[-1] // ds_level
        this_obj_size = [dim_y, dim_x, dim_x]
        comm.Barrier()

        if shared_file_object:
            # Create parallel npy
            if rank == 0:
                try:
                    os.makedirs(os.path.join(output_folder))
                except FileExistsError:
                    print('Target folder {} exists.'.format(output_folder))
                np.save(os.path.join(output_folder, 'intermediate_obj.npy'),
                        np.zeros([*this_obj_size, 2]))
                np.save(os.path.join(output_folder, 'intermediate_m.npy'),
                        np.zeros([*this_obj_size, 2]))
                np.save(os.path.join(output_folder, 'intermediate_v.npy'),
                        np.zeros([*this_obj_size, 2]))
            comm.Barrier()

            # Create memmap pointer on each rank
            dset = np.load(os.path.join(output_folder, 'intermediate_obj.npy'),
                           mmap_mode='r+',
                           allow_pickle=True)
            dset_m = np.load(os.path.join(output_folder, 'intermediate_m.npy'),
                             mmap_mode='r+',
                             allow_pickle=True)
            dset_v = np.load(os.path.join(output_folder, 'intermediate_v.npy'),
                             mmap_mode='r+',
                             allow_pickle=True)

            # Get block allocation
            n_blocks_y, n_blocks_x, n_blocks, block_size = get_block_division(
                this_obj_size, n_ranks)
            print_flush('Number of blocks in y: {}'.format(n_blocks_y), 0,
                        rank)
            print_flush('Number of blocks in x: {}'.format(n_blocks_x), 0,
                        rank)
            print_flush('Block size: {}'.format(block_size), 0, rank)
            probe_pos = []
            # probe_pos is a list of tuples of (line_st, line_end, px_st, px_end).
            for i_pos in range(n_blocks):
                probe_pos.append(
                    get_block_range(i_pos, n_blocks_x, block_size)[:4])
            probe_pos = np.array(probe_pos)
            if free_prop_cm not in [0, None]:
                safe_zone_width = ceil(4.0 * np.sqrt(
                    (delta_nm * dim_x + free_prop_cm * 1e7) * lmbda_nm) /
                                       (voxel_nm[0]))
            else:
                safe_zone_width = ceil(4.0 * np.sqrt(
                    (delta_nm * dim_x) * lmbda_nm) / (voxel_nm[0]))
            print_flush('safe zone: {}'.format(safe_zone_width), 0, rank)

        # read rotation data
        try:
            coord_ls = read_all_origin_coords(
                'arrsize_{}_{}_{}_ntheta_{}'.format(dim_y, dim_x, dim_x,
                                                    n_theta), n_theta)
        except:
            save_rotation_lookup([dim_y, dim_x, dim_x], n_theta)
            coord_ls = read_all_origin_coords(
                'arrsize_{}_{}_{}_ntheta_{}'.format(dim_y, dim_x, dim_x,
                                                    n_theta), n_theta)

        if minibatch_size is None:
            minibatch_size = n_theta

        if n_epochs_mask_release is None:
            n_epochs_mask_release = np.inf

        if (not shared_file_object) or (shared_file_object and rank == 0):
            try:
                mask = dxchange.read_tiff_stack(
                    os.path.join(save_path, 'fin_sup_mask', 'mask_00000.tiff'),
                    range(prj_0.shape[1]))
            except:
                try:
                    mask = dxchange.read_tiff(
                        os.path.join(save_path, 'fin_sup_mask', 'mask.tiff'))
                except:
                    obj_pr = dxchange.read_tiff_stack(
                        os.path.join(save_path,
                                     'paganin_obj/recon_00000.tiff'),
                        range(prj_0.shape[1]), 5)
                    obj_pr = gaussian_filter(np.abs(obj_pr),
                                             sigma=3,
                                             mode='constant')
                    mask = np.zeros_like(obj_pr)
                    mask[obj_pr > 1e-5] = 1
                    dxchange.write_tiff_stack(mask,
                                              os.path.join(
                                                  save_path,
                                                  'fin_sup_mask/mask'),
                                              dtype='float32',
                                              overwrite=True)
            if ds_level > 1:
                mask = mask[::ds_level, ::ds_level, ::ds_level]
            if shared_file_object:
                np.save(os.path.join(output_folder, 'intermediate_mask.npy'),
                        mask)
        comm.Barrier()

        if shared_file_object:
            dset_mask = np.load(os.path.join(output_folder,
                                             'intermediate_mask.npy'),
                                mmap_mode='r+',
                                allow_pickle=True)

        # unify random seed for all threads
        comm.Barrier()
        seed = int(time.time() / 60)
        np.random.seed(seed)
        comm.Barrier()

        if rank == 0:
            if not initializer_flag:
                if initial_guess is None:
                    print_flush('Initializing with Gaussian random.', 0, rank)
                    obj_delta = np.random.normal(size=[dim_y, dim_x, dim_x],
                                                 loc=8.7e-7,
                                                 scale=1e-7) * mask
                    obj_beta = np.random.normal(size=[dim_y, dim_x, dim_x],
                                                loc=5.1e-8,
                                                scale=1e-8) * mask
                    obj_delta[obj_delta < 0] = 0
                    obj_beta[obj_beta < 0] = 0
                else:
                    print_flush('Using supplied initial guess.', 0, rank)
                    sys.stdout.flush()
                    obj_delta = initial_guess[0]
                    obj_beta = initial_guess[1]
            else:
                print_flush('Initializing from previous pass outcomes.', 0, rank)
                obj_delta = dxchange.read_tiff(
                    os.path.join(output_folder,
                                 'delta_ds_{}.tiff'.format(ds_level * 2)))
                obj_beta = dxchange.read_tiff(
                    os.path.join(output_folder,
                                 'beta_ds_{}.tiff'.format(ds_level * 2)))
                obj_delta = upsample_2x(obj_delta)
                obj_beta = upsample_2x(obj_beta)
                obj_delta += np.random.normal(
                    size=[dim_y, dim_x, dim_x], loc=8.7e-7, scale=1e-7) * mask
                obj_beta += np.random.normal(
                    size=[dim_y, dim_x, dim_x], loc=5.1e-8, scale=1e-8) * mask
                obj_delta[obj_delta < 0] = 0
                obj_beta[obj_beta < 0] = 0
            obj_size = obj_delta.shape
            if object_type == 'phase_only':
                obj_beta[...] = 0
            elif object_type == 'absorption_only':
                obj_delta[...] = 0
            if not shared_file_object:
                np.save('init_delta_temp.npy', obj_delta)
                np.save('init_beta_temp.npy', obj_beta)
            else:
                dset[:, :, :, 0] = obj_delta
                dset[:, :, :, 1] = obj_beta
                dset_m[...] = 0
                dset_v[...] = 0
        comm.Barrier()

        if not shared_file_object:
            obj_delta = np.zeros(this_obj_size)
            obj_beta = np.zeros(this_obj_size)
            obj_delta[:, :, :] = np.load('init_delta_temp.npy',
                                         allow_pickle=True)
            obj_beta[:, :, :] = np.load('init_beta_temp.npy',
                                        allow_pickle=True)
            comm.Barrier()
            if rank == 0:
                os.remove('init_delta_temp.npy')
                os.remove('init_beta_temp.npy')
            comm.Barrier()

        print_flush('Initializing probe...', 0, rank)
        if not shared_file_object:
            if probe_type == 'plane':
                probe_real = np.ones([dim_y, dim_x])
                probe_imag = np.zeros([dim_y, dim_x])
            elif probe_type == 'optimizable':
                if probe_initial is not None:
                    probe_mag, probe_phase = probe_initial
                    probe_real, probe_imag = mag_phase_to_real_imag(
                        probe_mag, probe_phase)
                else:
                    # probe_mag = np.ones([dim_y, dim_x])
                    # probe_phase = np.zeros([dim_y, dim_x])
                    back_prop_cm = (free_prop_cm + (psize_cm * obj_size[2])
                                    ) if free_prop_cm is not None else (
                                        psize_cm * obj_size[2])
                    probe_init = create_probe_initial_guess(
                        os.path.join(save_path, fname), back_prop_cm * 1.e7,
                        energy_ev, psize_cm * 1.e7)
                    probe_real = probe_init.real
                    probe_imag = probe_init.imag
                if pupil_function is not None:
                    probe_real = probe_real * pupil_function
                    probe_imag = probe_imag * pupil_function
            elif probe_type == 'fixed':
                probe_mag, probe_phase = probe_initial
                probe_real, probe_imag = mag_phase_to_real_imag(
                    probe_mag, probe_phase)
            elif probe_type == 'point':
                # this should be in spherical coordinates
                probe_real = np.ones([dim_y, dim_x])
                probe_imag = np.zeros([dim_y, dim_x])
            elif probe_type == 'gaussian':
                probe_mag_sigma = kwargs['probe_mag_sigma']
                probe_phase_sigma = kwargs['probe_phase_sigma']
                probe_phase_max = kwargs['probe_phase_max']
                py = np.arange(obj_size[0]) - (obj_size[0] - 1.) / 2
                px = np.arange(obj_size[1]) - (obj_size[1] - 1.) / 2
                pxx, pyy = np.meshgrid(px, py)
                probe_mag = np.exp(-(pxx**2 + pyy**2) /
                                   (2 * probe_mag_sigma**2))
                probe_phase = probe_phase_max * np.exp(
                    -(pxx**2 + pyy**2) / (2 * probe_phase_sigma**2))
                probe_real, probe_imag = mag_phase_to_real_imag(
                    probe_mag, probe_phase)
            else:
                raise ValueError(
                    'Invalid wavefront type. Choose from \'plane\', \'optimizable\', \'fixed\', \'point\', \'gaussian\'.'
                )
        else:
            if probe_type == 'plane':
                probe_real = np.ones([block_size + 2 * safe_zone_width] * 2)
                probe_imag = np.zeros([block_size + 2 * safe_zone_width] * 2)
            else:
                raise ValueError(
                    'probe_type other than plane is not yet supported with shared file object.'
                )

        # =============finite support===================
        if not shared_file_object:
            obj_delta = obj_delta * mask
            obj_beta = obj_beta * mask
            obj_delta = np.clip(obj_delta, 0, None)
            obj_beta = np.clip(obj_beta, 0, None)
        # ==============================================

        # generate Fresnel kernel
        if not shared_file_object:
            # Kernel shape follows the object array: (dim_y, dim_x, n_slices).
            h = get_kernel(delta_nm,
                           lmbda_nm,
                           voxel_nm, [dim_y, dim_x, dim_x],
                           fresnel_approx=fresnel_approx)
        else:
            h = get_kernel(delta_nm,
                           lmbda_nm,
                           voxel_nm, [
                               block_size + safe_zone_width * 2,
                               block_size + safe_zone_width * 2, dim_x
                           ],
                           fresnel_approx=fresnel_approx)

        loss_grad = grad(calculate_loss, [0, 1])

        # Save convergence data
        os.makedirs(os.path.join(output_folder, 'convergence'), exist_ok=True)
        f_conv = open(
            os.path.join(output_folder, 'convergence',
                         'loss_rank_{}.txt'.format(rank)), 'w')
        f_conv.write('i_epoch,i_batch,time\n')

        print_flush('Optimizer started.', 0, rank)
        if rank == 0:
            create_summary(output_folder, locals(), preset='fullfield')

        cont = True
        i_epoch = 0
        while cont:
            if shared_file_object:
                # Do a ptychography-like allocation.
                n_pos = len(probe_pos)
                n_spots = n_theta * n_pos
                n_tot_per_batch = minibatch_size * n_ranks
                n_batch = int(np.ceil(float(n_spots) / n_tot_per_batch))
                ind_list_rand = []

                theta_ls = np.arange(n_theta)
                np.random.shuffle(theta_ls)

                for i, i_theta in enumerate(theta_ls):
                    spots_ls = range(n_pos)
                    if n_pos % minibatch_size != 0:
                        # Append randomly selected diffraction spots if necessary, so that a rank won't be given
                        # spots from different angles in one batch.
                        spots_ls = np.append(
                            spots_ls,
                            np.random.choice(
                                spots_ls[:-(n_pos % minibatch_size)],
                                minibatch_size - (n_pos % minibatch_size),
                                replace=False))
                    if i == 0:
                        ind_list_rand = np.vstack(
                            [np.array([i_theta] * len(spots_ls)),
                             spots_ls]).transpose()
                    else:
                        ind_list_rand = np.concatenate(
                            [ind_list_rand,
                             np.vstack(
                                 [np.array([i_theta] * len(spots_ls)),
                                  spots_ls]).transpose()],
                            axis=0)
                ind_list_rand = split_tasks(ind_list_rand, n_tot_per_batch)
                probe_size_half = block_size // 2 + safe_zone_width
            else:
                ind_list_rand = np.arange(n_theta)
                np.random.shuffle(ind_list_rand)
                n_tot_per_batch = n_ranks * minibatch_size
                if n_theta % n_tot_per_batch > 0:
                    ind_list_rand = np.concatenate([
                        ind_list_rand, ind_list_rand[:n_tot_per_batch -
                                                     n_theta % n_tot_per_batch]
                    ])
                ind_list_rand = split_tasks(ind_list_rand, n_tot_per_batch)
                ind_list_rand = [np.sort(x) for x in ind_list_rand]

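            # Reset the Adam moment estimates at the start of each epoch.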
            m, v = (None, None)
            t0 = time.time()
            for i_batch in range(len(ind_list_rand)):

                t00 = time.time()
                if not shared_file_object:
                    this_ind_batch = ind_list_rand[i_batch][
                        rank * minibatch_size:(rank + 1) * minibatch_size]
                    this_prj_batch = prj[
                        this_ind_batch, ::ds_level, ::ds_level]
                else:
                    if len(ind_list_rand[i_batch]) < n_tot_per_batch:
                        n_supp = n_tot_per_batch - len(ind_list_rand[i_batch])
                        ind_list_rand[i_batch] = np.concatenate([
                            ind_list_rand[i_batch], ind_list_rand[0][:n_supp]
                        ])

                    this_ind_batch = ind_list_rand[i_batch]
                    this_i_theta = this_ind_batch[rank * minibatch_size, 0]
                    this_ind_rank = this_ind_batch[
                        rank * minibatch_size:(rank + 1) * minibatch_size, 1]

                    this_prj_batch = []
                    for i_pos in this_ind_rank:
                        line_st, line_end, px_st, px_end = probe_pos[i_pos]
                        line_st_0 = max([0, line_st])
                        line_end_0 = min([dim_y, line_end])
                        px_st_0 = max([0, px_st])
                        px_end_0 = min([dim_x, px_end])
                        patch = prj[this_i_theta, ::ds_level, ::ds_level][
                            line_st_0:line_end_0, px_st_0:px_end_0]
                        if line_st < 0:
                            patch = np.pad(patch, [[-line_st, 0], [0, 0]],
                                           mode='constant')
                        if line_end > dim_y:
                            patch = np.pad(patch,
                                           [[0, line_end - dim_y], [0, 0]],
                                           mode='constant')
                        if px_st < 0:
                            patch = np.pad(patch, [[0, 0], [-px_st, 0]],
                                           mode='constant')
                        if px_end > dim_x:
                            patch = np.pad(patch,
                                           [[0, 0], [0, px_end - dim_x]],
                                           mode='constant')
                        this_prj_batch.append(patch)
                    this_prj_batch = np.array(this_prj_batch)
                    this_pos_batch = probe_pos[this_ind_rank]
                    this_pos_batch_safe = this_pos_batch + np.array([
                        -safe_zone_width, safe_zone_width, -safe_zone_width,
                        safe_zone_width
                    ])
                    # if ds_level > 1:
                    #     this_prj_batch = this_prj_batch[:, :, ::ds_level, ::ds_level]
                    comm.Barrier()

                    # Get values for local chunks of object_delta and beta; interpolate and read directly from HDF5
                    obj = get_rotated_subblocks(dset, this_pos_batch_safe,
                                                coord_ls[this_i_theta], None)
                    obj_delta = np.array(obj[:, :, :, :, 0])
                    obj_beta = np.array(obj[:, :, :, :, 1])
                    m = get_rotated_subblocks(dset_m, this_pos_batch,
                                              coord_ls[this_i_theta], None)
                    m = np.array([m[:, :, :, :, 0], m[:, :, :, :, 1]])
                    m_0 = np.copy(m)
                    v = get_rotated_subblocks(dset_v, this_pos_batch,
                                              coord_ls[this_i_theta], None)
                    v = np.array([v[:, :, :, :, 0], v[:, :, :, :, 1]])
                    v_0 = np.copy(v)
                    mask = get_rotated_subblocks(dset_mask,
                                                 this_pos_batch,
                                                 coord_ls[this_i_theta],
                                                 None,
                                                 monochannel=True)

                    mask_0 = np.copy(mask)

                # Update weight for reweighted L1
                if i_batch % 10 == 0 and i_epoch >= 1:
                    weight_l1 = np.max(obj_delta) / (abs(obj_delta) + 1e-8)
                else:
                    weight_l1 = np.ones_like(obj_delta)

                grads = loss_grad(obj_delta, obj_beta, this_ind_batch,
                                  this_prj_batch)
                if not shared_file_object:
                    this_grads = np.array(grads)
                    grads = np.zeros_like(this_grads)
                    comm.Allreduce(this_grads, grads)
                # grads = comm.allreduce(this_grads)
                grads = np.array(grads)
                grads = grads / n_ranks

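                # In shared-file mode, strip the safe-zone margins so that
                # only the interior block of each chunk is updated.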
                if shared_file_object:
                    grads = grads[
                        :, :,
                        safe_zone_width:safe_zone_width + block_size,
                        safe_zone_width:safe_zone_width + block_size, :]
                    obj_delta = obj_delta[
                        :,
                        safe_zone_width:obj_delta.shape[1] - safe_zone_width,
                        safe_zone_width:obj_delta.shape[2] - safe_zone_width,
                        :]
                    obj_beta = obj_beta[
                        :,
                        safe_zone_width:obj_beta.shape[1] - safe_zone_width,
                        safe_zone_width:obj_beta.shape[2] - safe_zone_width,
                        :]

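                # Adam update on the stacked (delta, beta) pair; m and v carry
                # the first- and second-moment estimates across batches.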
                (obj_delta, obj_beta), m, v = apply_gradient_adam(
                    np.array([obj_delta, obj_beta]),
                    grads,
                    i_batch,
                    m,
                    v,
                    step_size=learning_rate)

                # finite support
                obj_delta = obj_delta * mask
                obj_beta = obj_beta * mask
                obj_delta = np.clip(obj_delta, 0, None)
                obj_beta = np.clip(obj_beta, 0, None)

                # shrink wrap

                if shrink_cycle is not None:
                    if i_batch % shrink_cycle == 0 and i_batch > 0:
                        boolean = (obj_delta > 1e-12).astype('float')
                        if not shared_file_object:
                            mask = mask * boolean
                        else:
                            write_subblocks_to_file(dset_mask,
                                                    this_pos_batch,
                                                    boolean,
                                                    None,
                                                    coord_ls[this_i_theta],
                                                    probe_size_half,
                                                    mask=True)

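                # In shared-file mode, write back only this rank's increment
                # (the difference from the values read), scaled by 1 / n_ranks.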
                if shared_file_object:
                    obj = obj[
                        :,
                        safe_zone_width:obj.shape[1] - safe_zone_width,
                        safe_zone_width:obj.shape[2] - safe_zone_width, :, :]
                    obj_delta = obj_delta - obj[:, :, :, :, 0]
                    obj_beta = obj_beta - obj[:, :, :, :, 1]
                    obj_delta = obj_delta / n_ranks
                    obj_beta = obj_beta / n_ranks
                    write_subblocks_to_file(dset, this_pos_batch, obj_delta,
                                            obj_beta, coord_ls[this_i_theta],
                                            probe_size_half)
                    m = m - m_0
                    m /= n_ranks
                    write_subblocks_to_file(dset_m, this_pos_batch, m[0], m[1],
                                            coord_ls[this_i_theta],
                                            probe_size_half)
                    v = v - v_0
                    v /= n_ranks
                    write_subblocks_to_file(dset_v, this_pos_batch, v[0], v[1],
                                            coord_ls[this_i_theta],
                                            probe_size_half)

                if rank == 0:
                    if shared_file_object:
                        # dxchange.write_tiff(dset[:, :, :, 0],
                        #                     fname=os.path.join(output_folder, 'intermediate', 'current'.format(ds_level)),
                        #                     dtype='float32', overwrite=True)
                        dxchange.write_tiff(
                            dset[:, :, :, 0],
                            fname=os.path.join(
                                output_folder,
                                'current/delta_{}'.format(i_batch)),
                            dtype='float32',
                            overwrite=True)
                    else:
                        dxchange.write_tiff(obj_delta,
                                            fname=os.path.join(
                                                output_folder, 'intermediate',
                                                'current'),
                                            dtype='float32',
                                            overwrite=True)

                print_flush('Minibatch done in {} s (rank {})'.format(
                    time.time() - t00, rank))

                f_conv.write('{},{},{}\n'.format(i_epoch, i_batch,
                                                 time.time() - t_zero))
                f_conv.flush()

            if n_epochs != 'auto' and i_epoch == n_epochs - 1:
                cont = False
            i_epoch = i_epoch + 1

            # print_flush(
            #     'Epoch {} (rank {}); loss = {}; Delta-t = {} s; current time = {}.'.format(i_epoch, rank,
            #                                                         calculate_loss(obj_delta, obj_beta, this_ind_batch,
            #                                                                        this_prj_batch),
            #                                                         time.time() - t0, time.time() - t_zero))
            if rank == 0:
                if shared_file_object:
                    dxchange.write_tiff(dset[:, :, :, 0],
                                        fname=os.path.join(
                                            output_folder,
                                            'delta_ds_{}'.format(ds_level)),
                                        dtype='float32',
                                        overwrite=True)
                    dxchange.write_tiff(dset[:, :, :, 1],
                                        fname=os.path.join(
                                            output_folder,
                                            'beta_ds_{}'.format(ds_level)),
                                        dtype='float32',
                                        overwrite=True)
                else:
                    dxchange.write_tiff(obj_delta,
                                        fname=os.path.join(
                                            output_folder,
                                            'delta_ds_{}'.format(ds_level)),
                                        dtype='float32',
                                        overwrite=True)
                    dxchange.write_tiff(obj_beta,
                                        fname=os.path.join(
                                            output_folder,
                                            'beta_ds_{}'.format(ds_level)),
                                        dtype='float32',
                                        overwrite=True)

        print_flush('Current iteration finished.', 0, rank)
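The patch extraction in the minibatch loop above crops each projection to the
probe window and zero-pads whatever falls outside the detector. A minimal
standalone sketch of that pattern (the function name and 2D setup are
illustrative, not part of the snippet):

import numpy as np

def extract_padded_patch(img, line_st, line_end, px_st, px_end):
    """Crop img to a window, zero-padding any part outside the image."""
    h, w = img.shape
    patch = img[max(0, line_st):min(h, line_end),
                max(0, px_st):min(w, px_end)]
    pad = [[max(0, -line_st), max(0, line_end - h)],
           [max(0, -px_st), max(0, px_end - w)]]
    return np.pad(patch, pad, mode='constant')

img = np.arange(16.).reshape(4, 4)
print(extract_padded_patch(img, -1, 3, 2, 6).shape)  # (4, 4)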
Example #29
    def calculate_loss(obj_delta, obj_beta, this_i_theta, this_pos_batch,
                       this_prj_batch):

        obj_stack = np.stack([obj_delta, obj_beta], axis=3)
        obj_rot = apply_rotation(
            obj_stack, coord_ls[this_i_theta],
            'arrsize_{}_{}_{}_ntheta_{}'.format(*this_obj_size, n_theta))
        probe_pos_batch_ls = []
        exiting_ls = []
        i_dp = 0
        while i_dp < minibatch_size:
            probe_pos_batch_ls.append(
                this_pos_batch[i_dp:min([i_dp + n_dp_batch, minibatch_size])])
            i_dp += n_dp_batch

        # pad if needed
        pad_arr = np.array([[0, 0], [0, 0]])
        if probe_pos[:, 0].min() - probe_size_half[0] < 0:
            pad_len = probe_size_half[0] - probe_pos[:, 0].min()
            obj_rot = np.pad(obj_rot, ((pad_len, 0), (0, 0), (0, 0), (0, 0)),
                             mode='constant')
            pad_arr[0, 0] = pad_len
        if probe_pos[:, 0].max() + probe_size_half[0] > this_obj_size[0]:
            pad_len = (probe_pos[:, 0].max() + probe_size_half[0]
                       - this_obj_size[0])
            obj_rot = np.pad(obj_rot, ((0, pad_len), (0, 0), (0, 0), (0, 0)),
                             mode='constant')
            pad_arr[0, 1] = pad_len
        if probe_pos[:, 1].min() - probe_size_half[1] < 0:
            pad_len = probe_size_half[1] - probe_pos[:, 1].min()
            obj_rot = np.pad(obj_rot, ((0, 0), (pad_len, 0), (0, 0), (0, 0)),
                             mode='constant')
            pad_arr[1, 0] = pad_len
        if probe_pos[:, 1].max() + probe_size_half[1] > this_obj_size[1]:
            pad_len = (probe_pos[:, 1].max() + probe_size_half[1]
                       - this_obj_size[1])
            obj_rot = np.pad(obj_rot, ((0, 0), (0, pad_len), (0, 0), (0, 0)),
                             mode='constant')
            pad_arr[1, 1] = pad_len

        for k, pos_batch in enumerate(probe_pos_batch_ls):
            subobj_ls = []
            for j in range(len(pos_batch)):
                pos = pos_batch[j]
                pos = [int(x) for x in pos]
                pos[0] = pos[0] + pad_arr[0, 0]
                pos[1] = pos[1] + pad_arr[1, 0]
                subobj = obj_rot[
                    pos[0] - probe_size_half[0]:
                    pos[0] - probe_size_half[0] + probe_size[0],
                    pos[1] - probe_size_half[1]:
                    pos[1] - probe_size_half[1] + probe_size[1], :, :]
                subobj_ls.append(subobj)

            subobj_ls = np.stack(subobj_ls)
            exiting = multislice_propagate_cnn(subobj_ls[:, :, :, :, 0],
                                               subobj_ls[:, :, :, :, 1],
                                               probe_real,
                                               probe_imag,
                                               energy_ev,
                                               [psize_cm * ds_level] * 3,
                                               free_prop_cm='inf')
            exiting_ls.append(exiting)
        exiting_ls = np.concatenate(exiting_ls, 0)
        loss = np.mean((np.abs(exiting_ls) - np.abs(this_prj_batch))**2)

        return loss
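The four padding branches above grow obj_rot just enough that every probe
window indexes inside the array, recording the left pads in pad_arr so the
positions can be shifted accordingly. A compact sketch of the same idea for a
generic volume (the function and variable names are assumptions for
illustration):

import numpy as np

def pad_for_windows(vol, pos, half):
    # vol: (y, x, ...) array; pos: (N, 2) integer window centers;
    # half: (half_y, half_x) window half-sizes.
    pads = []
    for ax in range(2):
        before = max(0, half[ax] - pos[:, ax].min())
        after = max(0, pos[:, ax].max() + half[ax] - vol.shape[ax])
        pads.append((before, after))
    pads += [(0, 0)] * (vol.ndim - 2)
    # Positions must be shifted by the left pads before windowing.
    return np.pad(vol, pads, mode='constant'), (pads[0][0], pads[1][0])

vol = np.ones((8, 8, 3))
pos = np.array([[1, 1], [7, 7]])
padded, offset = pad_for_windows(vol, pos, (2, 2))
print(padded.shape, offset)  # (10, 10, 3) (1, 1)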
Example #30
    def set_list_quantiles(self, features):
        # Inner quantile boundaries (one row per level), padded with
        # -inf / +inf guards so every value falls inside some bin.
        quantiles = np.array(
            [np.quantile(features, (i + 1) / self.number_quantiles, axis=0)
             for i in range(self.number_quantiles - 1)])
        padded = np.pad(quantiles, 1, 'constant',
                        constant_values=(-np.inf, np.inf))
        # Per-feature boundary lists; drop rows created by feature-axis padding.
        self.list_quantiles = list(padded.T)[1:-1]
        self.initialized = True
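The -inf / +inf guards added by np.pad mean any value can later be binned
without explicit bounds checks. A small illustration of the same trick with
np.digitize (the lookup step is an assumption; it is not shown in the
snippet):

import numpy as np

values = np.array([-3.0, 0.2, 0.5, 9.0])
# Two inner boundaries for three bins, padded with -inf / +inf sentinels.
edges = np.pad(np.array([0.0, 1.0]), 1, 'constant',
               constant_values=(-np.inf, np.inf))
print(edges)                       # [-inf   0.   1.  inf]
print(np.digitize(values, edges))  # [1 2 2 3] -- every value gets a bin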
Example #31
    def apply_poly(self, x_poly, lst_poly):
        res = Poly()

        k_h, k_w = self.kernel

        lw, up = x_poly.lw.copy(), x_poly.up.copy()

        lw = lw.reshape(x_poly.shape)
        up = up.reshape(x_poly.shape)

        p = self.padding
        lw_pad = np.pad(lw, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
        up_pad = np.pad(up, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant')
        x_n, x_c, x_h, x_w = lw_pad.shape

        res_h = int((x_h - k_h) / self.stride) + 1
        res_w = int((x_w - k_w) / self.stride) + 1

        len_pad = x_c * x_h * x_w
        len_res = x_c * res_h * res_w

        c_idx, h_idx, w_idx = index2d(x_c, self.stride, self.kernel,
                                      (x_h, x_w))

        res_lw, res_up = [], []
        mx_lw_idx_lst, mx_up_idx_lst = [], []

        for c in range(x_c):
            for i in range(res_h * res_w):

                mx_lw_val, mx_lw_idx = -1e9, None

                for k in range(k_h * k_w):

                    h, w = h_idx[k, i], w_idx[k, i]
                    val = lw_pad[0, c, h, w]

                    if val > mx_lw_val:
                        mx_lw_val, mx_lw_idx = val, (c, h, w)

                mx_up_val, cnt = -1e9, 0

                for k in range(k_h * k_w):

                    h, w = h_idx[k, i], w_idx[k, i]
                    val = up_pad[0, c, h, w]

                    if val > mx_up_val:
                        mx_up_val = val

                    if mx_lw_idx != (c, h, w) and val > mx_lw_val:
                        cnt += 1

                res_lw.append(mx_lw_val)
                res_up.append(mx_up_val)

                mx_lw_idx_lst.append(mx_lw_idx)
                if cnt > 0: mx_up_idx_lst.append(None)
                else: mx_up_idx_lst.append(mx_lw_idx)

        res.lw = np.array(res_lw)
        res.up = np.array(res_up)

        res.le = np.zeros([len_res, len_pad + 1])
        res.ge = np.zeros([len_res, len_pad + 1])

        res.shape = (1, x_c, res_h, res_w)

        for i in range(len_res):
            c = mx_lw_idx_lst[i][0]
            h = mx_lw_idx_lst[i][1]
            w = mx_lw_idx_lst[i][2]

            idx = c * x_h * x_w + h * x_w + w
            res.ge[i, idx] = 1

            if mx_up_idx_lst[i] is None:
                res.le[i, -1] = res.up[i]
            else:
                res.le[i, idx] = 1

        # Remove the columns of res.le / res.ge that correspond to padded
        # pixels, so the constraint matrices refer only to the unpadded input.
        del_idx = []
        if self.padding > 0:
            del_idx = del_idx + list(range(self.padding * (x_w + 1)))
            mx = x_h - self.padding
            for i in range(self.padding + 1, mx):
                tmp = i * x_w
                del_idx = del_idx + list(
                    range(tmp - self.padding, tmp + self.padding))
            # Flat index is h * x_w + w, so the trailing pad block starts at
            # mx * x_w, less the right pad of the last interior row.
            del_idx = del_idx + list(
                range(mx * x_w - self.padding, x_h * x_w))

            tmp = np.array(del_idx)

            for i in range(1, x_c):
                offset = i * x_h * x_w
                del_idx = del_idx + list((tmp + offset).copy())

        res.le = np.delete(res.le, del_idx, 1)
        res.ge = np.delete(res.ge, del_idx, 1)

        return res
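The del_idx arithmetic above enumerates the flat indices of padded pixels by
hand. It can be cross-checked with np.pad itself: pad a block of ones with
zeros and read off where the zeros land (a sketch with freely chosen sizes):

import numpy as np

x_h, x_w, p = 5, 6, 1
marker = np.pad(np.ones((x_h - 2 * p, x_w - 2 * p)), p, mode='constant')
pad_idx = np.flatnonzero(marker.ravel() == 0)
print(pad_idx)  # flat indices of the padded pixels in one (x_h, x_w) channel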
Example #32
y_true = np.abs(np.random.rand(Ndata))

#w_true = np.abs(np.random.rand(Ndata))+1;

#true grid needs to be set up with noise
w_true_grid = np.zeros((n_grid, n_grid))
for x, y, w in zip(x_true, y_true, w_true):
    w_true_grid[np.argmin(np.abs(theta_grid - x)),
                np.argmin(np.abs(theta_grid - y))] = w

data4 = convolvesame_fft(w_true_grid,
                         psf) + sig_noise  # * np.random.randn(n_grid,n_grid);
data2 = Psi(w_true_grid) + sig_noise  # * np.random.randn(n_grid,n_grid);
data3 = np.real(
    fft.ifft2(
        fft.fft2(np.pad(w_true_grid, ((5, 0), (5, 0)), 'constant')) *
        fft.fft2(np.pad(psf, ((5, 0), (5, 0)), 'constant')))
) + sig_noise  # * np.random.randn(n_grid,n_grid);
data = np.real(
    fft.ifft2(fft.fft2(w_true_grid) *
              fft.fft2(psf))) + sig_noise * np.random.randn(n_grid, n_grid)
#print(data-data2);
'''
fig, ax = plt.subplots(1,2)
ax[0].imshow(w_true_grid);
ax[0].set_title('True Positions')
#ax[1].imshow(data3[:-4,:-4]);
ax[1].imshow(data4);
ax[1].set_title('Observed Data')
plt.show();
'''
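Above, data uses unpadded FFTs and therefore computes a circular convolution
that wraps at the edges, while data3 zero-pads both arrays first to
approximate a linear convolution. A compact standalone check of the
difference (independent of the variables above):

import numpy as np
from numpy import fft

a = np.random.rand(8, 8)
k = np.random.rand(3, 3)

# Circular convolution: the kernel wraps around the array edges.
circ = np.real(fft.ifft2(fft.fft2(a) * fft.fft2(k, s=a.shape)))

# Linear convolution: zero-pad both inputs to the full output size first.
s0, s1 = a.shape[0] + k.shape[0] - 1, a.shape[1] + k.shape[1] - 1
pa = np.pad(a, ((0, s0 - a.shape[0]), (0, s1 - a.shape[1])), mode='constant')
pk = np.pad(k, ((0, s0 - k.shape[0]), (0, s1 - k.shape[1])), mode='constant')
lin = np.real(fft.ifft2(fft.fft2(pa) * fft.fft2(pk)))
# The two agree away from the wrapped edges; only the borders differ.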