Example #1
 def shuffle_data(self):
     # shuffle before splitting the data
     n_data = self.features.shape[0]
     shuffle_idx = np.random.permutation(n_data)
     # apply the same permutation to both arrays (np.take returns new arrays, not in-place)
     self.features = np.take(self.features, shuffle_idx, axis=0)
     self.labels = np.take(self.labels, shuffle_idx, axis=0)
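
The pattern above generalizes to any pair of arrays that must stay row-aligned; a minimal sketch (the shapes are assumptions):

import numpy as np

features = np.arange(10).reshape(5, 2)   # 5 samples, 2 features (assumed shapes)
labels = np.arange(5)

idx = np.random.permutation(features.shape[0])
# np.take returns new arrays; using one index array keeps rows paired
features = np.take(features, idx, axis=0)
labels = np.take(labels, idx, axis=0)
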
Example #2
 def diagonal(self, X):
     d1 = self.kernel1.dimension
     X_shape = getval(X.shape)
     X1 = anp.take(X, range(0, d1), axis=1)
     X2 = anp.take(X, range(d1, X_shape[1]), axis=1)
     diag1 = self.kernel1.diagonal(X1)
     diag2 = self.kernel2.diagonal(X2)
     return diag1 * diag2
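
A minimal sketch of the column split used by this product kernel (d1 and the shapes are assumptions):

import autograd.numpy as anp

X = anp.arange(12.).reshape(3, 4)                 # 3 rows, 4 feature columns
d1 = 2                                            # assumed dimension of kernel1
X1 = anp.take(X, range(0, d1), axis=1)            # columns 0..d1-1 go to kernel1
X2 = anp.take(X, range(d1, X.shape[1]), axis=1)   # remaining columns go to kernel2
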
Example #3
    def forward(self, x):
        """
        Actual computation of the warping transformation (see details above)

        :param x: input data of size (n,1)
        """
        warping = self._warping()
        warping_a = anp.take(warping, [0], axis=0)
        warping_b = anp.take(warping, [1], axis=0)

        return 1. - anp.power(1. - anp.power(self._rescale(x), warping_a),
                              warping_b)
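
The expression computed above has the form of the Kumaraswamy CDF, w(x) = 1 - (1 - x^a)^b, applied to the rescaled input; a quick numeric sketch (the parameter values are assumptions):

import autograd.numpy as anp

a, b = 2.0, 3.0                                   # assumed warping parameters
x = anp.linspace(0.01, 0.99, 5)                   # stands in for the rescaled input
w = 1. - anp.power(1. - anp.power(x, a), b)       # monotone map of (0, 1) onto (0, 1)
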
Example #4
    def forward(self, X1, X2):
        d1 = self.kernel1.dimension
        X1_shape = getval(X1.shape)
        X2_shape = getval(X2.shape)
        X1_1 = anp.take(X1, range(0, d1), axis=1)
        X1_2 = anp.take(X1, range(d1, X1_shape[1]), axis=1)
        X2_1 = anp.take(X2, range(0, d1), axis=1)
        X2_2 = anp.take(X2, range(d1, X2_shape[1]), axis=1)

        kmat1 = self.kernel1(X1_1, X2_1)
        kmat2 = self.kernel2(X1_2, X2_2)
        return kmat1 * kmat2
Example #5
def restore_checkpoint(output_folder, shared_file_object=True, optimizer=None):

    i_epoch, i_batch = [
        int(i)
        for i in np.loadtxt(os.path.join(output_folder, 'checkpoint.txt'))
    ]
    if not shared_file_object:
        obj = np.load(os.path.join(output_folder, 'obj_checkpoint.npy'))
        obj_delta = np.take(obj, 0, axis=-1)
        obj_beta = np.take(obj, 1, axis=-1)
        optimizer.restore_param_arrays_from_checkpoint()
        return i_epoch, i_batch, obj_delta, obj_beta
    else:
        return i_epoch, i_batch
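
restore_checkpoint implies a matching writer: a two-number text file plus an optional object array with delta and beta stacked on the last axis. A hypothetical sketch of such a writer (the real save_checkpoint used later in this codebase may differ):

import os
import numpy as np

def save_checkpoint_sketch(output_folder, i_epoch, i_batch, obj=None):
    # write the counters in the plain-text form that np.loadtxt reads back
    np.savetxt(os.path.join(output_folder, 'checkpoint.txt'), [i_epoch, i_batch], fmt='%d')
    if obj is not None:
        # obj stacks delta and beta on the last axis, matching np.take(..., axis=-1)
        np.save(os.path.join(output_folder, 'obj_checkpoint.npy'), obj)
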
Example #6
def draw_zl1_ys(z_s, py_zl1, M):
    ''' Draw from p(z1 | y, s) proportional to p(y | z1) * p(z1 | s) for all s 
    z_s (list of nd-arrays): zl | s^l for all s^l and all l.
    py_zl1 (nd-array): p(y | z1_M) 
    M (list of int): The number of MC points on all layers
    ------------------------------------------------------------------------
    returns ((M1, numobs, r1, S1) nd-array): z^{(1)} | y, s
    '''

    epsilon = 1E-16

    numobs = py_zl1.shape[1]
    L = len(z_s) - 1
    S = [z_s[l].shape[2] for l in range(L)]
    r = [z_s[l].shape[1] for l in range(L + 1)]

    norm_cste = np.sum(py_zl1, axis=0, keepdims=True)
    norm_cste = np.where(norm_cste <= epsilon, epsilon, norm_cste)

    py_zl1_norm = py_zl1 / norm_cste

    zl1_ys = np.zeros((M[0], numobs, r[0], S[0]))
    for s in range(S[0]):
        qM_cum = py_zl1_norm[:, :, s].T.cumsum(axis=1)
        u = np.random.rand(numobs, 1, M[0])

        choices = u < qM_cum[..., np.newaxis]
        idx = choices.argmax(1)

        zl1_ys[:, :, :, s] = np.take(z_s[0][:, :, s], idx.T, axis=0)

    return zl1_ys
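
The u < cumsum comparison followed by argmax is inverse-CDF sampling from a categorical distribution; a minimal 1-D sketch:

import numpy as np

p = np.array([0.2, 0.5, 0.3])           # normalized probabilities
cdf = p.cumsum()                        # [0.2, 0.7, 1.0]
u = np.random.rand(1000, 1)
samples = (u < cdf).argmax(axis=1)      # first index whose CDF value exceeds u
# bin counts of `samples` approach 1000 * p as the sample count grows
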
Example #7
def stress(params, positions, cell, strain=np.zeros((3, 3))):
    """Compute the stress on a Lennard-Jones system.

    Parameters
    ----------

    params : dictionary of parameters.
      Defaults to {'sigma': 1.0, 'epsilon': 1.0}

    positions : array of floats. Shape = (natoms, 3)

    cell: array of unit cell vectors. Shape = (3, 3)

    Returns
    -------
    stress : an array of stress components. Shape = (6,)
    [sxx, syy, szz, syz, sxz, sxy]

    """
    dEdst = elementwise_grad(energy, 3)

    volume = np.abs(np.linalg.det(cell))

    der = dEdst(params, positions, cell, strain)
    result = (der + der.T) / 2 / volume
    return np.take(result, [0, 4, 8, 5, 2, 1])
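
The index list [0, 4, 8, 5, 2, 1] reads the flattened 3x3 tensor in Voigt order; a quick check:

import numpy as np

t = np.arange(9).reshape(3, 3)           # entry t[i, j] has flat index 3*i + j
voigt = np.take(t, [0, 4, 8, 5, 2, 1])   # [t00, t11, t22, t12, t02, t01]
# i.e. [sxx, syy, szz, syz, sxz, sxy] for a symmetric stress tensor
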
Example #8
    def fit(self, X, y, X_valid, y_valid):
        for epoch in range(self.num_of_epochs):
            print("epoch number: " + str(epoch + 1))
            permuted_indices = np.random.permutation(X.shape[0])
            for i in range(0, X.shape[0], self.batch_size):
                selected_data_points = np.take(permuted_indices, range(i, i+self.batch_size), mode='wrap')
                delta_w = self._d_cost(X[selected_data_points], y[selected_data_points], self.weights)
                for w, d in zip(self.weights, delta_w):
                    w -= d*self.learning_rate
                for j in range(len(self.weights)):  # 'j' avoids shadowing the batch index 'i'
                    self.ema_weights[j] = self.ema_weights[j]*self.ema + self.weights[j]*(1-self.ema)

            training_accuracy = compute_accuracy(self.predict(X, self.weights), np.argmax(y, 1))
            validation_accuracy = compute_accuracy(self.predict(X_valid, self.weights), np.argmax(y_valid, 1))

            print("training accuracy: " + str(round(training_accuracy, 2)))
            print("validation accuracy: " + str(round(validation_accuracy, 2)))

            print("cost: " + str(self._cost(X, y, self.weights)))

            training_accuracy = compute_accuracy(self.predict(X, self.ema_weights), np.argmax(y, 1))
            validation_accuracy = compute_accuracy(self.predict(X_valid, self.ema_weights), np.argmax(y_valid, 1))

            print("training accuracy ema: " + str(round(training_accuracy, 2)))
            print("validation accuracy ema: " + str(round(validation_accuracy, 2)))

            self.save_average_and_std(X)
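
The inner update maintains an exponential moving average of the weights; a scalar sketch (the decay value is an assumption):

ema = 0.99                                # assumed decay rate
ema_w, w = 0.0, 1.0
for _ in range(3):
    ema_w = ema_w * ema + w * (1 - ema)   # drifts toward w at rate (1 - ema)
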
Example #9
    def train(self, niter, wake_data, bar=True, snap=False):

        batch_data = OrderedDict()

        if bar:
            r = tqdm(range(niter), leave=True, desc="training", ncols=100)
        else:
            r = range(niter)

        N = list(wake_data.values())[0].shape[0]
        epoch = 0

        for i in r:

            # Reshuffle once per epoch, before slicing, so every key in
            # batch_data is taken from the same ordering of wake_data.
            if i * self.nwake >= N * (epoch + 1):
                shuffle_dict(wake_data)
                epoch += 1

            for k, v in wake_data.items():
                batch_data[k] = np.take(v,
                                        range(self.nwake * i,
                                              self.nwake * (i + 1)),
                                        axis=0,
                                        mode="wrap")

            self.sleep()
            self.wake(batch_data, i)
            if snap and i % snap == 0:
                self.save()

        self.niter = niter
        self.save()
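
mode='wrap' lets the batch index run past the end of the data and wrap back to the start, so the slice never raises; a minimal sketch:

import numpy as np

data = np.arange(10)
batch = np.take(data, range(8, 12), mode='wrap')   # indices 8, 9, 10, 11 -> [8, 9, 0, 1]
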
Example #10
    def fit(self, X, y, X_valid, y_valid):
        for epoch in range(self.num_of_epochs):
            print("epoch number: " + str(epoch + 1))
            permuted_indices = np.random.permutation(X.shape[0])
            for i in range(0, X.shape[0], self.batch_size):
                selected_data_points = np.take(permuted_indices,
                                               range(i, i + self.batch_size),
                                               mode='wrap')
                delta_w = self._d_cost(X[selected_data_points],
                                       y[selected_data_points], self.weights)
                self.weights -= delta_w * self.learning_rate

            training_accuracy = compute_accuracy(self.predict(X),
                                                 np.argmax(y, 1))
            validation_accuracy = compute_accuracy(self.predict(X_valid),
                                                   np.argmax(y_valid, 1))

            print("training accuracy: " + str(round(training_accuracy, 2)))
            print("validation accuracy: " + str(round(validation_accuracy, 2)))

            print("cost: " + str(self._cost(X, y, self.weights)))

            if self.validation_accuracy < validation_accuracy:
                self.validation_accuracy = validation_accuracy
                self.old_weights = self.weights.copy()  # snapshot; plain assignment would alias the array
            else:
                self.weights = self.old_weights
                self.learning_rate = 0.5 * self.learning_rate
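
Storing the fallback weights relies on old_weights being a snapshot rather than an alias, which is why the copy above matters under in-place updates; a minimal illustration:

import numpy as np

w = np.ones(3)
old = w            # alias, not a snapshot
w -= 0.5           # in-place: `old` now also reads [0.5, 0.5, 0.5]
old = w.copy()     # a real snapshot survives later in-place updates
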
Example #11
def stress(parameters, positions, numbers, cell, strain=np.zeros((3, 3))):
  """Compute the stress on an EMT system.

    Parameters
    ----------

    positions : array of floats. Shape = (natoms, 3)

    numbers : array of integers of atomic numbers (natoms,)

    cell: array of unit cell vectors. Shape = (3, 3)

    Returns
    -------
    stress : an array of stress components. Shape = (6,)
    [sxx, syy, szz, syz, sxz, sxy]

    """
  dEdst = elementwise_grad(energy, 4)

  volume = np.abs(np.linalg.det(cell))

  der = dEdst(parameters, positions, numbers, cell, strain)
  result = (der + der.T) / 2 / volume
  return np.take(result, [0, 4, 8, 5, 2, 1])
Example #12
  def test_stress(self):
    for structure in ('fcc', 'bcc', 'hcp', 'diamond', 'sc'):
      for repeat in ((1, 1, 1), (1, 2, 3)):
        for a in [3.0, 4.0]:
          atoms = bulk('Cu', structure, a=a).repeat(repeat)
          atoms.rattle()
          atoms.set_calculator(EMT())

          # Numerically calculate the ase stress
          d = 1e-9  # a delta strain

          ase_stress = np.empty((3, 3)).flatten()
          cell0 = np.array(atoms.cell)          # copy, so reassigning atoms.cell below doesn't clobber it
          positions0 = atoms.positions.copy()   # keep the unstrained positions for both strain signs

          # Use a finite difference approach that is centered.
          for i in [0, 4, 8, 5, 2, 1]:
            strain_tensor = np.zeros((3, 3))
            strain_tensor = strain_tensor.flatten()
            strain_tensor[i] = d
            strain_tensor = strain_tensor.reshape((3, 3))
            strain_tensor += strain_tensor.T
            strain_tensor /= 2
            strain_tensor += np.eye(3, 3)

            cell = np.dot(strain_tensor, cell0.T).T
            positions = np.dot(strain_tensor, positions0.T).T
            atoms.cell = cell
            atoms.positions = positions
            ep = atoms.get_potential_energy()

            strain_tensor = np.zeros((3, 3))
            strain_tensor = strain_tensor.flatten()
            strain_tensor[i] = -d
            strain_tensor = strain_tensor.reshape((3, 3))
            strain_tensor += strain_tensor.T
            strain_tensor /= 2
            strain_tensor += np.eye(3, 3)

            cell = np.dot(strain_tensor, cell0.T).T
            positions = np.dot(strain_tensor, positions0.T).T
            atoms.cell = cell
            atoms.positions = positions
            em = atoms.get_potential_energy()

            ase_stress[i] = (ep - em) / (2 * d) / atoms.get_volume()

          ase_stress = np.take(ase_stress.reshape((3, 3)), [0, 4, 8, 5, 2, 1])

          ag_stress = stress(parameters, atoms.positions, atoms.numbers, atoms.cell)

          # I picked the 0.03 tolerance here. I thought it should be closer, but
          # it is a simple numerical difference I am using for the derivative,
          # and I am not sure it is totally correct.
          self.assertTrue(np.all(np.abs(ase_stress - ag_stress) <= 0.03),
                          f'''
ase: {ase_stress}
ag : {ag_stress}
diff {ase_stress - ag_stress}
''')
Example #13
 def get_mini_batch_round_robins(self, dim, col):
     observed = self.tensor.find_observed_ui(dim, col)
     if len(observed) > self.batch_size:
         observed_idx = np.random.choice(len(observed), self.batch_size, replace=False)
         observed_subset = np.take(observed, observed_idx, axis=0)
     else:
         observed_subset = observed
     return observed_subset, len(observed)
Example #14
    def write_params_to_file(self,
                             this_pos_batch=None,
                             probe_size=None,
                             n_ranks=1):

        for param_name, p in self.params_chunk_array_dict.items():
            p = p - self.params_chunk_array_0_dict[param_name]
            p /= n_ranks
            dset_p = self.params_dset_dict[param_name]
            write_subblocks_to_file(dset_p,
                                    this_pos_batch,
                                    np.take(p, 0, axis=-1),
                                    np.take(p, 1, axis=-1),
                                    probe_size,
                                    self.whole_object_size[:-1],
                                    monochannel=False)
        return
Example #15
    def _compute_terms(self, X, alpha, mean_lam, gamma, delta, ret_mean=False):
        dim = self.kernel_x.dimension
        X_shape = getval(X.shape)
        cfg = anp.take(X, range(0, dim), axis=1)
        res = anp.take(X, range(dim, X_shape[1]), axis=1)
        kappa = self._compute_kappa(res, alpha, mean_lam)
        kr_pref = anp.reshape(gamma, (1, 1))

        if ret_mean or (self.encoding_delta is not None) or delta > 0.0:
            mean = self.mean_x(cfg)
        else:
            mean = None
        if self.encoding_delta is not None:
            kr_pref = anp.subtract(kr_pref, anp.multiply(delta, mean))
        elif delta > 0.0:
            kr_pref = anp.subtract(kr_pref, mean * delta)

        return cfg, res, kappa, kr_pref, mean
Example #16
    def forward(self, X):
        """
        Actual computation of warping applied to each column of X

        :param X: input data of size (n,dimension)
        """
        warped_X = []
        for col_index, transformation in enumerate(self.transformations):
            x = anp.take(X, [col_index], axis=1)
            warped_X.append(transformation(x))

        return anp.concatenate(warped_X, axis=1)
Example #17
 def write_chunks_to_file(self,
                          this_pos_batch,
                          arr_channel_0,
                          arr_channel_1,
                          probe_size,
                          write_difference=True,
                          dset_2=None):
     dset = self.dset if dset_2 is None else dset_2
     if write_difference:
         # n_ranks is assumed to be defined at module scope (e.g., the MPI world size)
         if self.monochannel:
             arr_channel_0 = arr_channel_0 - self.arr_0
             arr_channel_0 /= n_ranks
         else:
             arr_channel_0 = arr_channel_0 - np.take(self.arr_0, 0, axis=-1)
             arr_channel_1 = arr_channel_1 - np.take(self.arr_0, 1, axis=-1)
             arr_channel_0 /= n_ranks
             arr_channel_1 /= n_ranks
     write_subblocks_to_file(dset,
                             this_pos_batch,
                             arr_channel_0,
                             arr_channel_1,
                             probe_size,
                             self.full_size,
                             monochannel=self.monochannel)
Example #18
def replicate(x, state_map, axis=-1):
    """
    Replicate an array of shape (..., K) according to the given state map
    to get an array of shape (..., R) where R is the total number of states.

    Parameters
    ----------
    x : array_like, shape (..., K)
        The array to be replicated.

    state_map : array_like, shape (R,), int
        The mapping from [0, R) -> [0, K)
    """
    assert state_map.ndim == 1
    assert np.all(state_map >= 0) and np.all(state_map < x.shape[-1])
    return np.take(x, state_map, axis=axis)
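
A small usage sketch of replicate's indexing (the values are assumptions):

import numpy as np

x = np.array([[0.1, 0.9]])           # K = 2 columns
state_map = np.array([0, 0, 1])      # R = 3 states mapping into the columns of x
np.take(x, state_map, axis=-1)       # -> [[0.1, 0.1, 0.9]]
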
Example #19
def collapse(x, state_map, axis=-1):
    """
    Collapse an array of shape (..., R) to shape (..., K) by summing
    columns that map to the same state in [0, K).

    Parameters
    ----------
    x : array_like, shape (..., R)
        The array to be collapsed.

    state_map : array_like, shape (R,), int
        The mapping from [0, R) -> [0, K)
    """
    R = x.shape[axis]
    assert state_map.ndim == 1 and state_map.shape[0] == R
    K = state_map.max() + 1
    return np.concatenate([np.sum(np.take(x, np.where(state_map == k)[0], axis=axis),
                                  axis=axis, keepdims=True)
                           for k in range(K)], axis=axis)
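
And the corresponding collapse, using the function defined above to sum columns that share a state (the values are assumptions):

import numpy as np

x = np.array([[1., 2., 3.]])         # R = 3 columns
state_map = np.array([0, 0, 1])      # columns 0 and 1 share state 0
collapse(x, state_map)               # -> [[3., 3.]]
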
Example #20
def im2col(img, block_size=(5, 5), skip=1):
    """ stretches block_size size'd patches centered skip distance 
        away in both row/column space, stacks into columns (and stacks)
        bands into rows

        Use-case is for storing images for quick matrix multiplies
           - blows up memory usage by quite a bit (factor of 10!)

        motivated by implementation discussion here: 
            http://cs231n.github.io/convolutional-networks/

        edited from snippet here:
            http://stackoverflow.com/questions/30109068/implement-matlabs-im2col-sliding-in-python
    """
    # stack depth bands (colors)
    if len(img.shape) == 3:
        return np.vstack([
            im2col(img[:, :, k], block_size, skip)
            for k in range(img.shape[2])
        ])

    # input array and block size
    A = img
    B = block_size

    # Parameters
    M, N = A.shape
    col_extent = N - B[1] + 1
    row_extent = M - B[0] + 1

    # Get Starting block indices
    start_idx = np.arange(B[0])[:, None] * N + np.arange(B[1])

    # Get offset indices across the height and width of input array
    offset_idx = np.arange(0, row_extent, skip)[:, None] * N + np.arange(
        0, col_extent, skip)

    # Get all actual indices & index into input array for final output
    out = np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel())
    return out
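
A shape check for the function above on a small image (the sizes are assumptions):

import numpy as np

img = np.arange(36.).reshape(6, 6)
cols = im2col(img, block_size=(3, 3), skip=1)
assert cols.shape == (9, 16)   # one column per 3x3 patch of a 6x6 image
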
Example #22
def reconstruct_ptychography(
        # ______________________________________
        # |Raw data and experimental parameters|________________________________
        fname, probe_pos, probe_size, obj_size, theta_st=0, theta_end=PI, n_theta=None, theta_downsample=None,
        energy_ev=5000, psize_cm=1e-7, free_prop_cm=None,
        # ___________________________
        # |Reconstruction parameters|___________________________________________
        n_epochs='auto', crit_conv_rate=0.03, max_nepochs=200, alpha_d=None, alpha_b=None,
        gamma=1e-6, learning_rate=1.0, minibatch_size=None, multiscale_level=1, n_epoch_final_pass=None,
        initial_guess=None, n_batch_per_update=1, reweighted_l1=False, interpolation='bilinear',
        # ___________________________
        # |Finite support constraint|___________________________________________
        finite_support_mask_path=None, shrink_cycle=None, shrink_threshold=1e-9,
        # ___________________
        # |Object constraints|
        object_type='normal',
        # _______________
        # |Forward model|_______________________________________________________
        forward_algorithm='fresnel', binning=1, fresnel_approx=False, pure_projection=False, two_d_mode=False,
        probe_type='gaussian', probe_initial=None,
        # _____
        # |I/O|_________________________________________________________________
        save_path='.', output_folder=None, save_intermediate=False, full_intermediate=False, use_checkpoint=True,
        save_stdout=False,
        # _____________
        # |Performance|_________________________________________________________
        cpu_only=False, core_parallelization=True, shared_file_object=True, n_dp_batch=20,
        # __________________________
        # |Object optimizer options|____________________________________________
        optimizer='adam',
        # _________________________
        # |Other optimizer options|_____________________________________________
        probe_learning_rate=1e-3,
        optimize_probe_defocusing=False, probe_defocusing_learning_rate=1e-5,
        optimize_probe_pos_offset=False,
        # ________________
        # |Other settings|______________________________________________________
        dynamic_rate=True, pupil_function=None, probe_circ_mask=0.9, dynamic_dropping=False, dropping_threshold=8e-5,
        **kwargs,):
        # ______________________________________________________________________

    """
    Notes:
        1. Input data are assumed to be contained in an HDF5 under 'exchange/data', as a 4D dataset of
           shape [n_theta, n_spots, detector_size_y, detector_size_x].
        2. Full-field reconstruction is treated as ptychography. If the image is not divided, the programs
           runs as if it is dealing with ptychography with only 1 spot per angle.
        3. Full-field reconstruction with minibatch_size > 1 but without image dividing is not supported.
           In this case, minibatch_size will be forced to be 1, so that each rank process only one
           rotation angle's image at a time. To perform large fullfield reconstruction efficiently,
           divide the data into sub-chunks.
        4. Full-field reconstruction using shared_file_mode but without image dividing is not recommended
           even if minibatch_size is 1. In shared_file_mode, all ranks process data from the same rotation
           angle in each synchronized batch. Doing this will cause all ranks to process the same data.
           To perform large fullfield reconstruction efficiently, divide the data into sub-chunks.
    """

    def calculate_loss(obj_delta, obj_beta, probe_real, probe_imag, probe_defocus_mm, probe_pos_offset, this_i_theta, this_pos_batch, this_prj_batch):

        if optimize_probe_defocusing:
            h_probe = get_kernel(probe_defocus_mm * 1e6, lmbda_nm, voxel_nm, probe_size, fresnel_approx=fresnel_approx)
            probe_complex = probe_real + 1j * probe_imag
            probe_complex = np.fft.ifft2(np.fft.ifftshift(np.fft.fftshift(np.fft.fft2(probe_complex)) * h_probe))
            probe_real = np.real(probe_complex)
            probe_imag = np.imag(probe_complex)

        if optimize_probe_pos_offset:
            this_pos_batch = this_pos_batch + probe_pos_offset[this_i_theta]
        if not shared_file_object:
            obj_stack = np.stack([obj_delta, obj_beta], axis=3)
            if not two_d_mode:
                obj_rot = apply_rotation(obj_stack, coord_ls[this_i_theta])
                # obj_rot = sp_rotate(obj_stack, theta, axes=(1, 2), reshape=False)
            else:
                obj_rot = obj_stack
            probe_pos_batch_ls = []
            exiting_ls = []
            i_dp = 0
            while i_dp < minibatch_size:
                probe_pos_batch_ls.append(this_pos_batch[i_dp:min([i_dp + n_dp_batch, minibatch_size])])
                i_dp += n_dp_batch

            # Pad if needed
            obj_rot, pad_arr = pad_object(obj_rot, this_obj_size, probe_pos, probe_size)

            for k, pos_batch in enumerate(probe_pos_batch_ls):
                subobj_ls = []
                for j in range(len(pos_batch)):
                    pos = pos_batch[j]
                    pos = [int(x) for x in pos]
                    pos[0] = pos[0] + pad_arr[0, 0]
                    pos[1] = pos[1] + pad_arr[1, 0]
                    subobj = obj_rot[pos[0]:pos[0] + probe_size[0], pos[1]:pos[1] + probe_size[1], :, :]
                    subobj_ls.append(subobj)

                subobj_ls = np.stack(subobj_ls)
                exiting = multislice_propagate_batch_numpy(subobj_ls[:, :, :, :, 0], subobj_ls[:, :, :, :, 1], probe_real,
                                                           probe_imag, energy_ev, psize_cm * ds_level, kernel=h, free_prop_cm=free_prop_cm,
                                                           obj_batch_shape=[len(pos_batch), *probe_size, this_obj_size[-1]],
                                                           fresnel_approx=fresnel_approx, pure_projection=pure_projection)
                exiting_ls.append(exiting)
            exiting_ls = np.concatenate(exiting_ls, 0)
            loss = np.mean((np.abs(exiting_ls) - np.abs(this_prj_batch)) ** 2)

        else:
            probe_pos_batch_ls = []
            exiting_ls = []
            i_dp = 0
            while i_dp < minibatch_size:
                probe_pos_batch_ls.append(this_pos_batch[i_dp:min([i_dp + n_dp_batch, minibatch_size])])
                i_dp += n_dp_batch

            pos_ind = 0
            for k, pos_batch in enumerate(probe_pos_batch_ls):
                subobj_ls_delta = obj_delta[pos_ind:pos_ind + len(pos_batch), :, :, :]
                subobj_ls_beta = obj_beta[pos_ind:pos_ind + len(pos_batch), :, :, :]
                exiting = multislice_propagate_batch_numpy(subobj_ls_delta, subobj_ls_beta, probe_real,
                                                           probe_imag, energy_ev, psize_cm * ds_level, kernel=h,
                                                           free_prop_cm=free_prop_cm,
                                                           obj_batch_shape=[len(pos_batch), *probe_size,
                                                                            this_obj_size[-1]],
                                                           fresnel_approx=fresnel_approx,
                                                           pure_projection=pure_projection)
                exiting_ls.append(exiting)
                pos_ind += len(pos_batch)
            exiting_ls = np.concatenate(exiting_ls, 0)
            loss = np.mean((np.abs(exiting_ls) - np.abs(this_prj_batch)) ** 2)
            # dxchange.write_tiff(abs(exiting_ls._value[0]), output_folder + '/det/det', dtype='float32', overwrite=True)
            # raise

        # Regularization
        if reweighted_l1:
            if alpha_d not in [None, 0]:
                loss = loss + alpha_d * np.mean(weight_l1 * np.abs(obj_delta))
            if alpha_b not in [None, 0]:
                loss = loss + alpha_b * np.mean(weight_l1 * np.abs(obj_beta))
        else:
            if alpha_d not in [None, 0]:
                loss = loss + alpha_d * np.mean(np.abs(obj_delta))
            if alpha_b not in [None, 0]:
                loss = loss + alpha_b * np.mean(np.abs(obj_beta))
        if gamma not in [None, 0]:
            if shared_file_object:
                loss = loss + gamma * total_variation_3d(obj_delta, axis_offset=1)
            else:
                loss = loss + gamma * total_variation_3d(obj_delta, axis_offset=0)

        # Write convergence data
        global current_loss
        current_loss = loss._value
        f_conv.write('{},{},{},'.format(i_epoch, i_batch, current_loss))
        f_conv.flush()

        return loss

    comm = MPI.COMM_WORLD
    n_ranks = comm.Get_size()
    rank = comm.Get_rank()
    t_zero = time.time()

    timestr = str(datetime.datetime.today())
    timestr = timestr[:timestr.find('.')]
    for i in [':', '-', ' ']:
        if i == ' ':
            timestr = timestr.replace(i, '_')
        else:
            timestr = timestr.replace(i, '')

    # ================================================================================
    # Create pointer for raw data.
    # ================================================================================
    t0 = time.time()
    print_flush('Reading data...', 0, rank)
    f = h5py.File(os.path.join(save_path, fname), 'r')
    prj = f['exchange/data']
    if n_theta is None:
        n_theta = prj.shape[0]
    if two_d_mode:
        n_theta = 1
    prj_theta_ind = np.arange(n_theta, dtype=int)
    theta = -np.linspace(theta_st, theta_end, n_theta, dtype='float32')
    if theta_downsample is not None:
        theta = theta[::theta_downsample]
        prj_theta_ind = prj_theta_ind[::theta_downsample]
        n_theta = len(theta)
    original_shape = [n_theta, *prj.shape[1:]]

    print_flush('Data reading: {} s'.format(time.time() - t0), 0, rank)
    print_flush('Data shape: {}'.format(original_shape), 0, rank)
    comm.Barrier()

    not_first_level = False
    stdout_options = {'save_stdout': save_stdout, 'output_folder': output_folder, 
                      'timestamp': timestr}

    n_pos = len(probe_pos)
    probe_pos = np.array(probe_pos)

    # ================================================================================
    # Batching check.
    # ================================================================================
    if minibatch_size > 1 and n_pos == 1:
        warnings.warn('It seems that you are processing undivided fullfield data with '
                      'minibatch > 1. A rank can only process data from the same rotation '
                      'angle at a time. I am setting minibatch_size to 1.')
        minibatch_size = 1
    if shared_file_object and n_pos == 1:
        warnings.warn('It seems that you are processing undivided fullfield data with '
                      'shared_file_object=True. In shared-file mode, all ranks must '
                      'process data from the same rotation angle in each synchronized '
                      'batch.')

    # ================================================================================
    # Set output folder name if not specified.
    # ================================================================================
    if output_folder is None:
        output_folder = 'recon_{}'.format(timestr)
        if abs(PI - theta_end) < 1e-3:
            output_folder += '_180'
    print_flush('Output folder is {}'.format(output_folder), 0, rank)

    if save_path != '.':
        output_folder = os.path.join(save_path, output_folder)

    for ds_level in range(multiscale_level - 1, -1, -1):

        # ================================================================================
        # Set metadata.
        # ================================================================================
        ds_level = 2 ** ds_level
        print_flush('Multiscale downsampling level: {}'.format(ds_level), 0, rank, **stdout_options)
        comm.Barrier()

        prj_shape = original_shape

        if ds_level > 1:
            this_obj_size = [int(x / ds_level) for x in obj_size]
        else:
            this_obj_size = obj_size

        dim_y, dim_x = prj_shape[-2:]
        if minibatch_size is None:
            minibatch_size = n_pos
        comm.Barrier()

        # ================================================================================
        # Create output directory.
        # ================================================================================
        if rank == 0:
            try:
                os.makedirs(os.path.join(output_folder))
            except:
                print('Target folder {} exists.'.format(output_folder))
        comm.Barrier()

        # ================================================================================
        # Create object function optimizer.
        # ================================================================================
        if optimizer == 'adam':
            opt = AdamOptimizer([*this_obj_size, 2], output_folder=output_folder)
            optimizer_options_obj = {'step_size': learning_rate,
                                     'shared_file_object': shared_file_object}
        elif optimizer == 'gd':
            opt = GDOptimizer([*this_obj_size, 2], output_folder=output_folder)
            optimizer_options_obj = {'step_size': learning_rate,
                                     'dynamic_rate': True,
                                     'first_downrate_iteration': 20 * max([ceil(n_pos / (minibatch_size * n_ranks)), 1])}
        if shared_file_object:
            opt.create_file_objects(use_checkpoint=use_checkpoint)
        else:
            if use_checkpoint:
                try:
                    opt.restore_param_arrays_from_checkpoint()
                except:
                    opt.create_param_arrays()
            else:
                opt.create_param_arrays()

        # ================================================================================
        # Read rotation data.
        # ================================================================================
        try:
            coord_ls = read_all_origin_coords('arrsize_{}_{}_{}_ntheta_{}'.format(*this_obj_size, n_theta),
                                              n_theta)
        except:
            if rank == 0:
                print_flush('Saving rotation coordinates...', 0, rank, **stdout_options)
                save_rotation_lookup(this_obj_size, n_theta)
            comm.Barrier()
            coord_ls = read_all_origin_coords('arrsize_{}_{}_{}_ntheta_{}'.format(*this_obj_size, n_theta),
                                              n_theta)

        # ================================================================================
        # Unify random seed for all threads.
        # ================================================================================
        comm.Barrier()
        seed = int(time.time() / 60)
        np.random.seed(seed)
        comm.Barrier()

        # ================================================================================
        # Get checkpointed parameters.
        # ================================================================================
        starting_epoch, starting_batch = (0, 0)
        needs_initialize = not use_checkpoint
        if use_checkpoint and shared_file_object:
            try:
                starting_epoch, starting_batch = restore_checkpoint(output_folder, shared_file_object)
            except:
                needs_initialize = True

        elif use_checkpoint and (not shared_file_object):
            try:
                starting_epoch, starting_batch, obj_delta, obj_beta = restore_checkpoint(output_folder, shared_file_object, opt)
            except:
                needs_initialize = True

        # ================================================================================
        # Create object class.
        # ================================================================================
        obj = ObjectFunction([*this_obj_size, 2], shared_file_object=shared_file_object,
                             output_folder=output_folder, ds_level=ds_level, object_type=object_type)
        if shared_file_object:
            obj.create_file_object(use_checkpoint)
            obj.create_temporary_file_object()
            if needs_initialize:
                obj.initialize_file_object(save_stdout=save_stdout, timestr=timestr,
                                           not_first_level=not_first_level, initial_guess=initial_guess)
        else:
            if needs_initialize:
                obj.initialize_array(save_stdout=save_stdout, timestr=timestr,
                                     not_first_level=not_first_level, initial_guess=initial_guess)
            else:
                obj.delta = obj_delta
                obj.beta = obj_beta

        # ================================================================================
        # Create gradient class.
        # ================================================================================
        gradient = Gradient(obj)
        if shared_file_object:
            gradient.create_file_object()
            gradient.initialize_gradient_file()
        else:
            gradient.initialize_array_with_values(np.zeros(this_obj_size), np.zeros(this_obj_size))

        # ================================================================================
        # If a finite support mask path is specified (common for full-field imaging),
        # create an instance of monochannel mask class. While finite_support_mask_path
        # has to point to a 3D tiff file, the mask will be written as an HDF5 if
        # share_file_mode is True.
        # ================================================================================
        mask = None
        if finite_support_mask_path is not None:
            mask = Mask(this_obj_size, finite_support_mask_path, shared_file_object=shared_file_object,
                        output_folder=output_folder, ds_level=ds_level)
            if shared_file_object:
                mask.create_file_object(use_checkpoint=use_checkpoint)
                mask.initialize_file_object()
            else:
                mask_arr = dxchange.read_tiff(finite_support_mask_path)
                mask.initialize_array_with_values(mask_arr)

        # ================================================================================
        # Initialize probe functions.
        # ================================================================================
        print_flush('Initializing probe...', 0, rank, **stdout_options)
        probe_real, probe_imag = initialize_probe(probe_size, probe_type, pupil_function=pupil_function, probe_initial=probe_initial,
                             save_stdout=save_stdout, output_folder=output_folder, timestr=timestr,
                             save_path=save_path, fname=fname, **kwargs)

        # ================================================================================
        # generate Fresnel kernel.
        # ================================================================================
        voxel_nm = np.array([psize_cm] * 3) * 1.e7 * ds_level
        lmbda_nm = 1240. / energy_ev
        delta_nm = voxel_nm[-1]
        h = get_kernel(delta_nm * binning, lmbda_nm, voxel_nm, probe_size, fresnel_approx=fresnel_approx)

        # ================================================================================
        # Create other optimizers (probe, probe defocus, probe positions, etc.).
        # ================================================================================
        opt_arg_ls = [0, 1]
        if probe_type == 'optimizable':
            opt_probe = GDOptimizer([*probe_size, 2], output_folder=output_folder)
            optimizer_options_probe = {'step_size': probe_learning_rate,
                                       'dynamic_rate': True,
                                       'first_downrate_iteration': 4 * max([ceil(n_pos / (minibatch_size * n_ranks)), 1])}
            opt_arg_ls = opt_arg_ls + [2, 3]
            opt_probe.set_index_in_grad_return(len(opt_arg_ls))

        probe_defocus_mm = np.array(0.0)
        if optimize_probe_defocusing:
            opt_probe_defocus = GDOptimizer([1], output_folder=output_folder)
            optimizer_options_probe_defocus = {'step_size': probe_defocusing_learning_rate,
                                               'dynamic_rate': True,
                                               'first_downrate_iteration': 4 * max([ceil(n_pos / (minibatch_size * n_ranks)), 1])}
            opt_arg_ls.append(4)
            opt_probe_defocus.set_index_in_grad_return(len(opt_arg_ls))

        probe_pos_offset = np.zeros([n_theta, 2])
        if optimize_probe_pos_offset:
            opt_probe_pos_offset = GDOptimizer(probe_pos_offset.shape, output_folder=output_folder)
            optimizer_options_probe_pos_offset = {'step_size': 0.5,
                                                  'dynamic_rate': False}
            opt_arg_ls.append(5)
            opt_probe_pos_offset.set_index_in_grad_return(len(opt_arg_ls))

        # ================================================================================
        # Get gradient of loss function w.r.t. optimizable variables.
        # ================================================================================
        loss_grad = grad(calculate_loss, opt_arg_ls)

        # ================================================================================
        # Save convergence data.
        # ================================================================================
        if rank == 0:
            try:
                os.makedirs(os.path.join(output_folder, 'convergence'))
            except:
                pass
        comm.Barrier()
        f_conv = open(os.path.join(output_folder, 'convergence', 'loss_rank_{}.txt'.format(rank)), 'w')
        f_conv.write('i_epoch,i_batch,loss,time\n')

        # ================================================================================
        # Create parameter summary file.
        # ================================================================================
        print_flush('Optimizer started.', 0, rank, **stdout_options)
        if rank == 0:
            create_summary(output_folder, locals(), preset='ptycho')

        # ================================================================================
        # Start outer (epoch) loop.
        # ================================================================================
        cont = True
        i_epoch = starting_epoch
        m_p, v_p, m_pd, v_pd = (None, None, None, None)
        while cont:
            n_pos = len(probe_pos)
            n_spots = n_theta * n_pos
            n_tot_per_batch = minibatch_size * n_ranks
            n_batch = int(np.ceil(float(n_spots) / n_tot_per_batch))

            t0 = time.time()
            spots_ls = range(n_spots)
            ind_list_rand = []

            t00 = time.time()
            print_flush('Allocating jobs over threads...', 0, rank, **stdout_options)
            # Make a list of all thetas and spot positions
            np.random.seed(i_epoch)
            comm.Barrier()
            if not two_d_mode:
                theta_ls = np.arange(n_theta)
                np.random.shuffle(theta_ls)
            else:
                theta_ls = np.linspace(0, 2 * PI, prj.shape[0])
                theta_ls = abs(theta_ls - theta_st) < 1e-5
                i_theta = np.nonzero(theta_ls)[0][0]
                theta_ls = np.array([i_theta])

            # ================================================================================
            # Put diffraction spots from all angles together, and divide into minibatches.
            # ================================================================================
            for i, i_theta in enumerate(theta_ls):
                spots_ls = range(n_pos)
                # ================================================================================
                # Append randomly selected diffraction spots if necessary, so that a rank won't be given
                # spots from different angles in one batch.
                # When using shared file object, we must also ensure that all ranks deal with data at the
                # same angle at a time.
                # ================================================================================
                if not shared_file_object and n_pos % minibatch_size != 0:
                    spots_ls = np.append(spots_ls, np.random.choice(spots_ls,
                                                                    minibatch_size - (n_pos % minibatch_size),
                                                                    replace=False))
                elif shared_file_object and n_pos % n_tot_per_batch != 0:
                    spots_ls = np.append(spots_ls, np.random.choice(spots_ls,
                                                                    n_tot_per_batch - (n_pos % n_tot_per_batch),
                                                                    replace=False))
                # ================================================================================
                # Create task list for the current angle.
                # ind_list_rand is in the format of [((5, 0), (5, 1), ...), ((17, 0), (17, 1), ..., (...))]
                #                                    |___________________|   |_____|
                #                       a batch for all ranks  _|               |_ (i_theta, i_spot)
                #                    (minibatch_size * n_ranks)
                # ================================================================================
                if i == 0:
                    ind_list_rand = np.vstack([np.array([i_theta] * len(spots_ls)), spots_ls]).transpose()
                else:
                    ind_list_rand = np.concatenate(
                        [ind_list_rand, np.vstack([np.array([i_theta] * len(spots_ls)), spots_ls]).transpose()], axis=0)
            ind_list_rand = split_tasks(ind_list_rand, n_tot_per_batch)

            print_flush('Allocation done in {} s.'.format(time.time() - t00), 0, rank, **stdout_options)

            current_i_theta = 0
            for i_batch in range(starting_batch, n_batch):

                # ================================================================================
                # Initialize.
                # ================================================================================
                print_flush('Epoch {}, batch {} of {} started.'.format(i_epoch, i_batch, n_batch), 0, rank, **stdout_options)
                opt.i_batch = 0

                # ================================================================================
                # Save checkpoint.
                # ================================================================================
                if shared_file_object:
                    save_checkpoint(i_epoch, i_batch, output_folder, shared_file_object=True,
                                    obj_array=None, optimizer=opt)
                    obj.f.flush()
                else:
                    save_checkpoint(i_epoch, i_batch, output_folder, shared_file_object=False,
                                    obj_array=np.stack([obj.delta, obj.beta], axis=-1), optimizer=opt)

                # ================================================================================
                # Get scan position, rotation angle indices, and raw data for current batch.
                # ================================================================================
                t00 = time.time()
                if len(ind_list_rand[i_batch]) < n_tot_per_batch:
                    n_supp = n_tot_per_batch - len(ind_list_rand[i_batch])
                    ind_list_rand[i_batch] = np.concatenate([ind_list_rand[i_batch], ind_list_rand[0][:n_supp]])

                this_ind_batch = ind_list_rand[i_batch]
                this_i_theta = this_ind_batch[rank * minibatch_size, 0]
                this_ind_rank = np.sort(this_ind_batch[rank * minibatch_size:(rank + 1) * minibatch_size, 1])
                this_pos_batch = probe_pos[this_ind_rank]
                print_flush('Current rank is processing angle ID {}.'.format(this_i_theta), 0, rank, **stdout_options)

                t_prj_0 = time.time()
                this_prj_batch = prj[this_i_theta, this_ind_rank]
                print_flush('  Raw data reading done in {} s.'.format(time.time() - t_prj_0), 0, rank, **stdout_options)

                # ================================================================================
                # In shared file mode, if moving to a new angle, rotate the HDF5 object and save
                # the rotated object into the temporary file object.
                # ================================================================================
                if shared_file_object and this_i_theta != current_i_theta:
                    current_i_theta = this_i_theta
                    print_flush('  Rotating dataset...', 0, rank, **stdout_options)
                    t_rot_0 = time.time()
                    obj.rotate_data_in_file(coord_ls[this_i_theta], interpolation=interpolation, dset_2=obj.dset_rot)
                    # opt.rotate_files(coord_ls[this_i_theta], interpolation=interpolation)
                    # if mask is not None: mask.rotate_data_in_file(coord_ls[this_i_theta], interpolation=interpolation)
                    comm.Barrier()
                    print_flush('  Dataset rotation done in {} s.'.format(time.time() - t_rot_0), 0, rank, **stdout_options)

                if ds_level > 1:
                    this_prj_batch = this_prj_batch[:, :, ::ds_level, ::ds_level]
                comm.Barrier()

                if shared_file_object:
                    # ================================================================================
                    # Get values for local chunks of object_delta and beta; interpolate and read directly from HDF5
                    # ================================================================================
                    t_read_0 = time.time()
                    obj_rot = obj.read_chunks_from_file(this_pos_batch, probe_size, dset_2=obj.dset_rot)
                    print_flush('  Chunk reading done in {} s.'.format(time.time() - t_read_0), 0, rank, **stdout_options)
                    obj_delta = np.array(obj_rot[:, :, :, :, 0])
                    obj_beta = np.array(obj_rot[:, :, :, :, 1])
                    opt.get_params_from_file(this_pos_batch, probe_size)
                else:
                    obj_delta = obj.delta
                    obj_beta = obj.beta

                # Update weight for reweighted L1
                if shared_file_object:
                    weight_l1 = np.max(obj_delta) / (abs(obj_delta) + 1e-8)
                else:
                    if i_batch % 10 == 0: weight_l1 = np.max(obj_delta) / (abs(obj_delta) + 1e-8)

                # ================================================================================
                # Calculate object gradients.
                # ================================================================================
                t_grad_0 = time.time()
                grads = loss_grad(obj_delta, obj_beta, probe_real, probe_imag, probe_defocus_mm, probe_pos_offset, this_i_theta, this_pos_batch, this_prj_batch)
                print_flush('  Gradient calculation done in {} s.'.format(time.time() - t_grad_0), 0, rank, **stdout_options)
                grads = list(grads)

                # ================================================================================
                # Reshape object gradient to [y, x, z, c] or [n, y, x, z, c] and average over
                # ranks.
                # ================================================================================
                if shared_file_object:
                    obj_grads = np.stack(grads[:2], axis=-1)
                else:
                    this_obj_grads = np.stack(grads[:2], axis=-1)
                    obj_grads = np.zeros_like(this_obj_grads)
                    comm.Barrier()
                    comm.Allreduce(this_obj_grads, obj_grads)
                obj_grads = obj_grads / n_ranks

                # ================================================================================
                # Update object function with optimizer if not shared_file_object; otherwise,
                # just save the gradient chunk into the gradient file.
                # ================================================================================
                if not shared_file_object:
                    effective_iter = i_batch // max([ceil(n_pos / (minibatch_size * n_ranks)), 1])
                    obj_temp = opt.apply_gradient(np.stack([obj_delta, obj_beta], axis=-1), obj_grads, effective_iter,
                                                            **optimizer_options_obj)
                    obj_delta = np.take(obj_temp, 0, axis=-1)
                    obj_beta = np.take(obj_temp, 1, axis=-1)
                else:
                    t_grad_write_0 = time.time()
                    gradient.write_chunks_to_file(this_pos_batch, np.take(obj_grads, 0, axis=-1),
                                                  np.take(obj_grads, 1, axis=-1), probe_size,
                                                  write_difference=False)
                    print_flush('  Gradient writing done in {} s.'.format(time.time() - t_grad_write_0), 0, rank, **stdout_options)
                # ================================================================================
                # Nonnegativity and phase/absorption-only constraints for non-shared-file-mode,
                # and update arrays in instance.
                # ================================================================================
                if not shared_file_object:
                    obj_delta = np.clip(obj_delta, 0, None)
                    obj_beta = np.clip(obj_beta, 0, None)
                    if object_type == 'absorption_only': obj_delta[...] = 0
                    if object_type == 'phase_only': obj_beta[...] = 0
                    obj.delta = obj_delta
                    obj.beta = obj_beta

                # ================================================================================
                # Optimize probe and other parameters if necessary.
                # ================================================================================
                if probe_type == 'optimizable':
                    this_probe_grads = np.stack(grads[2:4], axis=-1)
                    probe_grads = np.zeros_like(this_probe_grads)
                    comm.Allreduce(this_probe_grads, probe_grads)
                    probe_grads = probe_grads / n_ranks
                    probe_temp = opt_probe.apply_gradient(np.stack([probe_real, probe_imag], axis=-1), probe_grads, **optimizer_options_probe)
                    probe_real = np.take(probe_temp, 0, axis=-1)
                    probe_imag = np.take(probe_temp, 1, axis=-1)

                if optimize_probe_defocusing:
                    this_pd_grad = np.array(grads[opt_probe_defocus.index_in_grad_returns])
                    pd_grads = np.array(0.0)
                    comm.Allreduce(this_pd_grad, pd_grads)
                    pd_grads = pd_grads / n_ranks
                    probe_defocus_mm = opt_probe_defocus.apply_gradient(probe_defocus_mm, pd_grads,
                                                                        **optimizer_options_probe_defocus)
                    print_flush('  Probe defocus is {} mm.'.format(probe_defocus_mm), 0, rank,
                                **stdout_options)

                if optimize_probe_pos_offset:
                    this_pos_offset_grad = np.array(grads[opt_probe_pos_offset.index_in_grad_returns])
                    pos_offset_grads = np.zeros_like(probe_pos_offset)
                    comm.Allreduce(this_pos_offset_grad, pos_offset_grads)
                    pos_offset_grads = pos_offset_grads / n_ranks
                    probe_pos_offset = opt_probe_pos_offset.apply_gradient(probe_pos_offset, pos_offset_grads,
                                                                        **optimizer_options_probe_pos_offset)

                # ================================================================================
                # For shared-file-mode, if finishing or about to move to a different angle,
                # rotate the gradient back, and use it to update the object at 0 deg. Then
                # update the object using gradient at 0 deg.
                # ================================================================================
                if shared_file_object and (i_batch == n_batch - 1 or ind_list_rand[i_batch + 1][0, 0] != current_i_theta):
                    coord_new = read_origin_coords('arrsize_{}_{}_{}_ntheta_{}'.format(*this_obj_size, n_theta),
                                                   this_i_theta, reverse=True)
                    print_flush('  Rotating gradient dataset back...', 0, rank, **stdout_options)
                    t_rot_0 = time.time()
                    # dxchange.write_tiff(gradient.dset[:, :, :, 0], 'adhesin/test_shared_file/grad_prerot', dtype='float32')
                    # gradient.reverse_rotate_data_in_file(coord_ls[this_i_theta], interpolation=interpolation)
                    gradient.rotate_data_in_file(coord_new, interpolation=interpolation)
                    # dxchange.write_tiff(gradient.dset[:, :, :, 0], 'adhesin/test_shared_file/grad_postrot', dtype='float32')
                    # comm.Barrier()
                    print_flush('  Gradient rotation done in {} s.'.format(time.time() - t_rot_0), 0, rank, **stdout_options)

                    t_apply_grad_0 = time.time()
                    opt.apply_gradient_to_file(obj, gradient, **optimizer_options_obj)
                    print_flush('  Object update done in {} s.'.format(time.time() - t_apply_grad_0), 0, rank, **stdout_options)
                    gradient.initialize_gradient_file()

                # ================================================================================
                # Apply finite support mask if specified.
                # ================================================================================
                if mask is not None:
                    if not shared_file_object:
                        obj.apply_finite_support_mask_to_array(mask)
                    else:
                        obj.apply_finite_support_mask_to_file(mask)
                    print_flush('  Mask applied.', 0, rank, **stdout_options)

                # ================================================================================
                # Update finite support mask if necessary.
                # ================================================================================
                if mask is not None and shrink_cycle is not None:
                    if i_batch % shrink_cycle == 0 and i_batch > 0:
                        if shared_file_object:
                            mask.update_mask_file(obj, shrink_threshold)
                        else:
                            mask.update_mask_array(obj, shrink_threshold)
                        print_flush('  Mask updated.', 0, rank, **stdout_options)

                # ================================================================================
                # Save intermediate object.
                # ================================================================================
                if rank == 0 and save_intermediate:
                    if shared_file_object:
                        dxchange.write_tiff(obj.dset[:, :, :, 0],
                                            fname=os.path.join(output_folder, 'intermediate', 'current'.format(ds_level)),
                                            dtype='float32', overwrite=True)
                    else:
                        dxchange.write_tiff(obj.delta,
                                            fname=os.path.join(output_folder, 'intermediate', 'current'.format(ds_level)),
                                            dtype='float32', overwrite=True)
                comm.Barrier()
                print_flush('Minibatch done in {} s; loss (rank 0) is {}.'.format(time.time() - t00, current_loss), 0, rank, **stdout_options)
                f_conv.write('{}\n'.format(time.time() - t_zero))
                f_conv.flush()

            if n_epochs == 'auto':
                pass
            else:
                if i_epoch == n_epochs - 1: cont = False

            i_epoch = i_epoch + 1

            average_loss = 0
            print_flush(
                'Epoch {} (rank {}); Delta-t = {} s; current time = {} s.'.format(i_epoch, rank,
                                                                    time.time() - t0, time.time() - t_zero), **stdout_options)
            if rank == 0 and save_intermediate:
                if shared_file_object:
                    dxchange.write_tiff(obj.dset[:, :, :, 0],
                                        fname=os.path.join(output_folder, 'delta_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(obj.dset[:, :, :, 1],
                                        fname=os.path.join(output_folder, 'beta_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(np.sqrt(probe_real ** 2 + probe_imag ** 2),
                                        fname=os.path.join(output_folder, 'probe_mag_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(np.arctan2(probe_imag, probe_real),
                                        fname=os.path.join(output_folder, 'probe_phase_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                else:
                    dxchange.write_tiff(obj.delta,
                                        fname=os.path.join(output_folder, 'delta_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(obj.beta,
                                        fname=os.path.join(output_folder, 'beta_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(np.sqrt(probe_real ** 2 + probe_imag ** 2),
                                        fname=os.path.join(output_folder, 'probe_mag_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
                    dxchange.write_tiff(np.arctan2(probe_imag, probe_real),
                                        fname=os.path.join(output_folder, 'probe_phase_ds_{}'.format(ds_level)),
                                        dtype='float32', overwrite=True)
            print_flush('Current iteration finished.', 0, rank, **stdout_options)
        comm.Barrier()