# Example #1
# 0
def load_training_data(filename: str):
    """Load the ``data`` dataset of an HDF5 file fully into memory.

    The original returned the h5py ``Dataset`` object after closing its file,
    which leaves the caller with a handle that can no longer be read. The
    values are now materialised (``[()]``) while the file is still open.

    :param filename: path of the HDF5 file containing a ``data`` dataset.
    :returns: the contents of ``data`` as an in-memory array.
    :raises SystemExit: if ``filename`` does not exist.
    """
    if not os_path.exists(filename):
        sys_exit(f'{filename} not found')
    with h5py_File(filename, "r") as f:
        features_dset_train = f["data"][()]
    return features_dset_train
# Example #2
# 0
def get_mat_obj_from_h5py(mat_fpath):
    """Read every top-level member of an HDF5 file into a plain dict.

    Presumably used for MATLAB v7.3 ``.mat`` files (HDF5-based) — confirm
    against callers.

    :param mat_fpath: path of the HDF5 file to read.
    :returns: dict mapping each top-level member name to ``np_array(member)``.
    """
    # Build the dict while the file is still open; the values are in-memory
    # arrays, so they remain valid after the ``with`` block closes the file.
    with h5py_File(mat_fpath, 'r') as mat_file:
        return {name: np_array(member) for name, member in mat_file.items()}
# Example #3
# 0
def main():
    """Flag which bounding boxes to keep for every shot and store the flags.

    Rank 0 counts the shots in every HDF5 file matched by ``args.glob`` (via
    the length of each file's ``h5_path`` dataset) and splits the
    ``(file_index, shot_index)`` pairs across ``size`` ranks. Each rank then
    evaluates, per shot, which bounding boxes pass every enabled filter:

    * resolution ring ``args.reshigh <= d < args.reslow``,
    * ``SNR >= args.snrmin`` (stored ``SNR_Leslie99`` or a dirty estimate),
    * ``tilt_rms < args.tiltfilt``,
    * ``tilt_error <= args.tilterrmax`` or below the ``args.tilterrperc``
      percentile,
    * ``indexed_flag > 0`` (unless ``args.notindexed``),
    * not ``is_on_boundary`` (unless ``args.onboundary``),
    * optional thinning to every ``args.keeperstride``-th kept box.

    The per-shot boolean flag lists are gathered on rank 0, which writes each
    one to ``bboxes/<args.keeperstag><shot_idx>`` in the source file (opened
    ``"r+"``), replacing any existing dataset with that name.

    Relies on module-level names not visible here (``args``, ``rank``,
    ``size``, ``comm``, ``has_mpi``, ``MPI``, ``Integrator``, ``plt``, ...).
    """
    # some parameters
    int_radius = 5
    gain = args.gain
    # data is stored in 39 h5py_Files
    resmin = args.reshigh  # high res cutoff
    resmax = args.reslow  # low res cutoff
    fnames = glob(args.glob)
    if rank == 0:
        print("CMD looked like:")
        print(" ".join(sys.argv))

    # NOTE: for reference, inside each h5 file there is
    #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

    # get the total number of shots using worker 0
    if rank == 0:
        print("I am root. I am calculating total number of shots")
        h5s = [h5py_File(f, "r") for f in fnames]
        Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
        Nshots_tot = sum(Nshots_per_file)
        print("I am root. Total number of shots is %d" % Nshots_tot)

        print("I am root. I will divide shots amongst workers.")
        shot_tuples = []
        for i_f, fname in enumerate(fnames):
            fidx_shotidx = [(i_f, i_shot)
                            for i_shot in range(Nshots_per_file[i_f])]
            shot_tuples += fidx_shotidx

        from numpy import array_split
        print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
        # one chunk of (file_idx, shot_idx) rows per rank
        shots_for_rank = array_split(shot_tuples, size)

        # close the open h5s..
        for h in h5s:
            h.close()

    else:
        Nshots_tot = None
        shots_for_rank = None
        h5s = None

    # Nshots_tot = comm.bcast( Nshots_tot, root=0)
    if has_mpi:
        shots_for_rank = comm.bcast(shots_for_rank, root=0)
    # h5s = comm.bcast( h5s, root=0)  # pull in the open hdf5 files

    my_shots = shots_for_rank[rank]

    # open the unique filenames for this rank
    # TODO: check max allowed pointers to open hdf5 file
    my_unique_fids = set([fidx for fidx, _ in my_shots])
    my_open_files = {
        fidx: h5py_File(fnames[fidx], "r")
        for fidx in my_unique_fids
    }

    all_num_kept = 0
    all_num_below = 0
    Ntot = 0
    all_kept_bbox = []
    all_is_kept_flags = []
    for img_num, (fname_idx, shot_idx) in enumerate(my_shots):

        h = my_open_files[fname_idx]

        # load the dxtbx image data directly:
        img_path = h["h5_path"][shot_idx]
        try:
            # h5py may return the path as bytes; str has no .decode on Py3
            img_path = img_path.decode()
        except AttributeError:
            pass
        img_data = numpy_load(img_path)["img"]
        bboxes = h["bboxes"]["shot%d" % shot_idx][()]
        panel_ids = h["panel_ids"]["shot%d" % shot_idx][()]
        hi, ki, li = h["Hi"]["shot%d" % shot_idx][()].T
        nspots = len(bboxes)

        if "below_zero" in list(h.keys()):
            # tilt dips below zero
            below_zero = h["below_zero"]["shot%d" % shot_idx][()]
        else:
            below_zero = None

        # use the known cell to compute the resolution of the spots
        #FIXME get the hard coded unit cell outta here!

        if args.resoinfile:
            reso = h["resolution"]["shot%d" % shot_idx][()]
        else:
            # d-spacing from 1/d^2 = (h^2 + k^2)/a^2 + l^2/c^2 with hard-coded
            # (a, c) = (114, 32.5) for --p9, otherwise (79.1, 38.4)
            if args.p9:
                reso = 1 / sqrt((hi**2 + ki**2) / 114 / 114 +
                                li**2 / 32.5 / 32.5)
            else:
                # TODO: why does 0,0,0 ever appear as a reflection ? Should never happen...
                reso = 1 / sqrt((hi**2 + ki**2) / 79.1 / 79.1 +
                                li**2 / 38.4 / 38.4)

        # half-open interval: resmin inclusive, resmax exclusive
        in_reso_ring = array([resmin <= d < resmax for d in reso])

        # Dirty integrater, sets integration region as disk of diameter 2*int_radius pixels
        if len(img_data.shape) == 2:  # single panel image
            assert len(set(panel_ids)) == 1  # sanity check
            img_data = [img_data]

        # start from the resolution filter; each filter below can only turn
        # True entries into False
        is_a_keeper = [in_reso_ring[i_spot] for i_spot in range(nspots)]

        print("In reso: %d" % sum(is_a_keeper))

        hgroups = h.keys()

        if args.snrmin is not None:
            if "SNR_Leslie99" in hgroups:
                SNR = h["SNR_Leslie99"]["shot%d" % shot_idx][()]
            else:
                if rank == 0:
                    print("WARNING USING DIRTY SNR ESTIMATE!")
                # one integrator per panel present in this shot
                dirties = {
                    pid: Integrator(img_data[pid],
                                    int_radius=int_radius,
                                    gain=gain)
                    for pid in set(panel_ids)
                }

                int_data = [
                    dirties[pid].integrate_bbox_dirty(bb)
                    for pid, bb in zip(panel_ids, bboxes)
                ]

                # signal, background, variance  # these are from the paper Leslie '99
                s, b, var = map(array, zip(*int_data))
                SNR = s / sqrt(var)
            is_a_keeper = [
                k and snr >= args.snrmin for k, snr in zip(is_a_keeper, SNR)
            ]

        print("In reso and SNR: %d" % sum(is_a_keeper))

        if "tilt_rms" in hgroups:
            if args.tiltfilt is not None:
                tilt_rms = h["tilt_rms"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and rms < args.tiltfilt
                    for k, rms in zip(is_a_keeper, tilt_rms)
                ]

        if "tilt_error" in hgroups:
            # absolute threshold takes precedence over the percentile one
            if args.tilterrmax is not None:
                tilt_err = h["tilt_error"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and err <= args.tilterrmax
                    for k, err in zip(is_a_keeper, tilt_err)
                ]
            elif args.tilterrperc is not None:
                assert 0 < args.tilterrperc < 100
                tilt_err = h["tilt_error"]["shot%d" % shot_idx][()]
                val = percentile(tilt_err, args.tilterrperc)
                is_a_keeper = [
                    k and err < val for k, err in zip(is_a_keeper, tilt_err)
                ]
        else:
            if rank == 0:
                print("WARNING: tilt_error not in hdf5 file")

        if "indexed_flag" in hgroups:
            #TODO change me to assume indexed_flag is a bool
            if not args.notindexed:
                indexed_flag = h["indexed_flag"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and (idx > 0)
                    for k, idx in zip(is_a_keeper, indexed_flag)
                ]
        else:
            if rank == 0:
                print("WARNING: indexed_flag not in hdf5 file")

        if "is_on_boundary" in hgroups:
            if not args.onboundary:
                on_boundary = h["is_on_boundary"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and not onbound
                    for k, onbound in zip(is_a_keeper, on_boundary)
                ]
        else:
            if rank == 0:
                print("WARNING: is_on_boundary not in hdf5 file")

        if args.keeperstride is not None:
            # keep only every keeperstride-th of the currently kept boxes
            from numpy import where
            where_kept = where(is_a_keeper)[0]
            keeper_pos = where_kept[::args.keeperstride]
            for w in where_kept:
                if w not in keeper_pos:
                    is_a_keeper[w] = False

        nkept = np_sum(is_a_keeper)
        if rank == 0:
            print("Keeping %d out of %d spots" % (nkept, nspots))
        # NOTE(review): all_num_kept is only accumulated for shots that carry a
        # below_zero dataset, so the final "TOTAL BBOXES KEPT" counts only
        # those shots — confirm this is intended.
        if below_zero is not None:
            num_below = np_sum(below_zero[is_a_keeper])
            all_num_below += num_below
            all_num_kept += nkept

        if rank == 0 and args.plot is not None:
            # draw the kept boxes on each panel; args.plot == -1 blocks on
            # plt.show(), otherwise it is the pause duration
            for pid in set(panel_ids):
                plt.gcf().clear()
                plt.imshow(img_data[pid], vmax=250, cmap='viridis')
                for i_spot in range(nspots):
                    if not is_a_keeper[i_spot]:
                        continue
                    if not panel_ids[i_spot] == pid:
                        continue
                    x1, x2, y1, y2 = bboxes[i_spot]
                    patch = plt.Rectangle(xy=(x1, y1),
                                          width=x2 - x1,
                                          height=y2 - y1,
                                          fc='none',
                                          ec='r')
                    plt.gca().add_patch(patch)
                plt.title("Panel=%d" % pid)
                if args.plot == -1:
                    plt.show()
                else:
                    plt.draw()
                    plt.pause(args.plot)

        kept_bboxes = [
            bboxes[i_bb] for i_bb in range(len(bboxes)) if is_a_keeper[i_bb]
        ]

        # pixel count of each kept box (bbox layout is (i1, i2, j1, j2))
        tot_pix = [(j2 - j1) * (i2 - i1)
                   for i_bb, (i1, i2, j1, j2) in enumerate(kept_bboxes)]
        Ntot += sum(tot_pix)
        if rank == 0:
            print("%g total pixels (file %d / %d)" %
                  (Ntot, img_num + 1, len(my_shots)))
        all_kept_bbox += map(list, kept_bboxes)
        all_is_kept_flags += [(fname_idx, shot_idx, is_a_keeper)
                              ]  # store this information, write to disk

    # close the open hdf5 files so we can write to them again
    for h in my_open_files.values():
        h.close()

    print("END OF LOOP")
    print("Rank %d; total bboxes=%d; Total pixels=%g" %
          (rank, len(all_kept_bbox), Ntot))
    # NOTE(review): the bcast above is guarded by has_mpi, but these
    # gather/reduce calls are not — confirm MPI/comm exist without MPI.
    all_kept_bbox = MPI.COMM_WORLD.gather(all_kept_bbox, root=0)
    all_is_kept_flags = MPI.COMM_WORLD.gather(all_is_kept_flags, root=0)

    all_kept = comm.reduce(all_num_kept)
    all_below = comm.reduce(all_num_below)
    if rank == 0:
        print("TOTAL BBOXES KEPT = %d; TOTAL BBOXES BELOW ZERO=%d" %
              (all_kept, all_below))
    if rank == 0:
        # flatten the per-rank lists returned by gather
        all_kept_bbox = [
            bbox for bbox_lst in all_kept_bbox for bbox in bbox_lst
        ]
        Ntot_pix = sum([(j2 - j1) * (i2 - i1)
                        for i1, i2, j1, j2 in all_kept_bbox])
        print("\n<><><><><><><<><><><><><><><><><><><><><><>")
        print("I am root. total bboxes=%d, Total pixels=%g" %
              (len(all_kept_bbox), Ntot_pix))
        print("<><><><><><><<><><><><><><><><><><><><><><>\n")

        print("I am root. I will store flags for each bbox on each shot")

        all_flag_info = [i for sl in all_is_kept_flags for i in sl]  # flatten

        # open the hdf5 files in read+write mode and store the bbox keeper flags
        h5s = {i_f: h5py_File(f, "r+") for i_f, f in enumerate(fnames)}

        for i_info, (fidx, shot_idx, keeper_flags) in enumerate(all_flag_info):
            bbox_grp = h5s[fidx]["bboxes"]

            flag_name = "%s%d" % (args.keeperstag, shot_idx)

            # overwrite flags from a previous run
            if flag_name in bbox_grp:
                del bbox_grp[flag_name]

            bbox_grp.create_dataset(flag_name,
                                    data=keeper_flags,
                                    dtype=bool,
                                    compression='lzf')

            if i_info % 5 == 0:
                print("I am root. I saved bbox selection flags ( %d / %d ) " %
                      (i_info + 1, len(all_flag_info)))

        # close the open files..
        for h in h5s.values():
            h.close()
# Example #4
# 0
    def load(self):
        """Distribute shots across MPI ranks and refine each shot with diffBragg.

        Rank 0 counts the shots in every file of ``self.fnames`` (length of the
        ``h5_path`` dataset) and splits the ``(file_index, shot_index)`` pairs
        into ``size`` chunks, broadcast when ``has_mpi`` is set. For every shot
        it owns (at most ``self.Nload``; stops once ``args.Nmax`` shots have
        been visited), a rank:

        * loads the image and beam from the ``.npz`` referenced by ``h5_path``,
        * builds a starting crystal from ``Amatrices`` plus the keeper-filtered
          bounding boxes, tilt planes and panel ids,
        * reads the ground-truth crystal and spectrum from the companion
          simulation ``.h5`` file (``h5_path`` with ``.npz`` removed),
        * runs ``RefineAllMultiPanel`` and prints the refined misorientation
          and unit cell next to their starting values.

        Reads ``self.fnames``, ``self.Nload``, ``self.flux_min`` and
        ``self.gain``; resets ``self.all_bbox_pixels`` to ``[]``. Relies on
        module-level names not visible here (``rank``, ``size``, ``comm``,
        ``has_mpi``, ``args``, ``wavelens``, ``Fhkl_guess``, ``CSPAD``, ...).

        Fixes relative to the previous version: Python 3 compatible ``print``,
        ``args.Nmax`` stops the loop instead of skipping one shot, the
        per-shot simulation file is closed, bytes paths from h5py are decoded,
        and a failed initial ground-truth comparison no longer reports the
        previous shot's value.
        """
        import os  # used for os.path.basename under args.testmode

        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

        # get the total number of shots using worker 0
        if rank == 0:
            print("I am root. I am calculating total number of shots")
            h5s = [h5py_File(f, "r") for f in self.fnames]
            Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
            Nshots_tot = sum(Nshots_per_file)
            print("I am root. Total number of shots is %d" % Nshots_tot)

            print("I am root. I will divide shots amongst workers.")
            shot_tuples = []
            for i_f, fname in enumerate(self.fnames):
                fidx_shotidx = [(i_f, i_shot)
                                for i_shot in range(Nshots_per_file[i_f])]
                shot_tuples += fidx_shotidx

            from numpy import array_split
            print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
            # one chunk of (file_idx, shot_idx) rows per rank
            shots_for_rank = array_split(shot_tuples, size)

            # close the open h5s..
            for h in h5s:
                h.close()

        else:
            shots_for_rank = None

        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            my_shots = my_shots[:self.Nload]

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        my_unique_fids = {fidx for fidx, _ in my_shots}
        my_open_files = {
            fidx: h5py_File(self.fnames[fidx], "r")
            for fidx in my_unique_fids
        }
        Ntot = 0
        self.all_bbox_pixels = []
        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            if args.Nmax is not None and img_num >= args.Nmax:
                # Already processed maximum number images! (the old
                # `== Nmax: continue` skipped only that single shot)
                break
            h = my_open_files[fname_idx]

            # path of the .npz image dump; h5py may return bytes on Python 3
            raw_path = h["h5_path"][shot_idx]
            try:
                raw_path = raw_path.decode()
            except AttributeError:
                pass

            # load the dxtbx image data directly:
            npz_path = raw_path
            # NOTE take me out!
            if args.testmode:
                npz_path = os.path.basename(npz_path)
            img_handle = numpy_load(npz_path)
            img = img_handle["img"]

            if len(img.shape) == 2:  # if single panel>>
                img = np.array([img])

            #D = det_from_dict(img_handle["det"][()])
            B = beam_from_dict(img_handle["beam"][()])

            # get the indexed crystal Amatrix
            Amat = h["Amatrices"][shot_idx]
            amat_elems = list(sqr(Amat).inverse().elems)
            # real space basis vectors:
            a_real = amat_elems[:3]
            b_real = amat_elems[3:6]
            c_real = amat_elems[6:]

            # dxtbx indexed crystal model
            C = Crystal(a_real, b_real, c_real, "P43212")

            # change basis here ? Or maybe just average a/b
            a, b, c, _, _, _ = C.get_unit_cell().parameters()
            a_init = .5 * (a + b)
            c_init = c

            # shoe boxes where we expect spots
            bbox_dset = h["bboxes"]["shot%d" % shot_idx]
            n_bboxes_total = bbox_dset.shape[0]
            # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
            is_a_keeper = h["bboxes"]["keepers%d" % shot_idx][()]

            # tilt plane to the background pixels in the shoe boxes
            tilt_abc_dset = h["tilt_abc"]["shot%d" % shot_idx]
            try:
                panel_ids_dset = h["panel_ids"]["shot%d" % shot_idx]
                has_panels = True
            except KeyError:
                has_panels = False

            # apply the filters:
            bboxes = [
                bbox_dset[i_bb] for i_bb in range(n_bboxes_total)
                if is_a_keeper[i_bb]
            ]
            tilt_abc = [
                tilt_abc_dset[i_bb] for i_bb in range(n_bboxes_total)
                if is_a_keeper[i_bb]
            ]
            if has_panels:
                panel_ids = [
                    panel_ids_dset[i_bb] for i_bb in range(n_bboxes_total)
                    if is_a_keeper[i_bb]
                ]
            else:
                panel_ids = [0] * len(tilt_abc)

            # how many pixels do we have
            tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]
            Ntot += sum(tot_pix)

            # Here we will try a per-shot refinement of the unit cell and Umatrix, as well as ncells abc
            # and spot scale etc..

            # load some ground truth data from the simulation dumps (e.g. spectrum)
            h5_fname = raw_path.replace(".npz", "")
            # NOTE remove me
            if args.testmode:
                h5_fname = os.path.basename(h5_fname)
            # read everything needed up front so the file is closed every shot
            with h5py_File(h5_fname, "r") as data:
                data_fname = data.filename
                crystal_A = data["crystalA"][()]
                fluxes = data["spectrum"][()]
                es = data["exposure_s"][()]

            tru = sqr(crystal_A).inverse().elems
            a_tru = tru[:3]
            b_tru = tru[3:6]
            c_tru = tru[6:]
            C_tru = Crystal(a_tru, b_tru, c_tru, "P43212")
            # nan (not a stale value from the previous shot) if comparison fails
            angular_offset_init = float("nan")
            try:
                angular_offset_init = compare_with_ground_truth(
                    a_tru, b_tru, c_tru, [C], symbol="P43212")[0]
            except Exception as err:
                print(
                    "Rank %d: Boo cant use the comparison w GT function: %s" %
                    (rank, err))

            fluxes *= es  # multiply by the exposure time
            spectrum = zip(wavelens, fluxes)
            # dont simulate when there are no photons!
            spectrum = [(wave, flux) for wave, flux in spectrum
                        if flux > self.flux_min]

            # make a unit cell manager that the refiner will use to track the B-matrix
            aa, _, cc, _, _, _ = C_tru.get_unit_cell().parameters()
            ucell_man = TetragonalManager(a=a_init, c=c_init)
            if args.startwithtruth:
                ucell_man = TetragonalManager(a=aa, c=cc)

            # create the sim_data instance that the refiner will use to run diffBragg
            # create a nanoBragg crystal
            nbcryst = nanoBragg_crystal()
            nbcryst.dxtbx_crystal = C
            if args.startwithtruth:
                nbcryst.dxtbx_crystal = C_tru

            nbcryst.thick_mm = 0.1
            nbcryst.Ncells_abc = 30, 30, 30
            nbcryst.miller_array = Fhkl_guess.as_amplitude_array()
            nbcryst.n_mos_domains = 1
            nbcryst.mos_spread_deg = 0.0

            # create a nanoBragg beam
            nbbeam = nanoBragg_beam()
            nbbeam.size_mm = 0.001
            nbbeam.unit_s0 = B.get_unit_s0()
            nbbeam.spectrum = spectrum

            # sim data instance
            SIM = SimData()
            SIM.detector = CSPAD
            #SIM.detector = D
            SIM.crystal = nbcryst
            SIM.beam = nbbeam
            SIM.panel_id = 0  # default

            spot_scale = 12
            if args.sad:
                spot_scale = 1
            SIM.instantiate_diffBragg(default_F=0, oversample=0)
            SIM.D.spot_scale = spot_scale

            img_in_photons = img / self.gain

            print("Rank %d, Starting refinement!" % rank)
            try:
                RUC = RefineAllMultiPanel(
                    spot_rois=bboxes,
                    abc_init=tilt_abc,
                    img=img_in_photons,  # NOTE this is now a multi panel image
                    SimData_instance=SIM,
                    plot_images=args.plot,
                    plot_residuals=args.residual,
                    ucell_manager=ucell_man)

                RUC.panel_ids = panel_ids
                RUC.multi_panel = True
                RUC.split_evaluation = args.split
                RUC.trad_conv = True
                RUC.refine_detdist = False
                RUC.refine_background_planes = False
                RUC.refine_Umatrix = True
                RUC.refine_Bmatrix = True
                RUC.refine_ncells = True
                RUC.use_curvatures = False  # args.curvatures
                RUC.calc_curvatures = True  #args.curvatures
                RUC.refine_crystal_scale = True
                RUC.refine_gain_fac = False
                RUC.plot_stride = args.stride
                RUC.poisson_only = False
                RUC.trad_conv_eps = 5e-3  # NOTE this is for single panel model
                RUC.max_calls = 300
                RUC.verbose = False
                RUC.use_rot_priors = True
                RUC.use_ucell_priors = True
                if args.verbose:
                    if rank == 0:  # only show refinement stats for rank 0
                        RUC.verbose = True
                RUC.run()
                # second pass with curvatures if the refiner asked for it
                if RUC.hit_break_to_use_curvatures:
                    RUC.use_curvatures = True
                    RUC.run(setup=False)
            except AssertionError as err:
                print(
                    "Rank %d, filename %s Hit assertion error during refinement: %s"
                    % (rank, data_fname, err))
                continue

            angle, ax = RUC.get_correction_misset(as_axis_angle_deg=True)
            if args.startwithtruth:
                C = Crystal(a_tru, b_tru, c_tru, "P43212")
            C.rotate_around_origin(ax, angle)
            C.set_B(RUC.get_refined_Bmatrix())
            a_ref, _, c_ref, _, _, _ = C.get_unit_cell().parameters()
            # compute missorientation with ground truth model
            try:
                angular_offset = compare_with_ground_truth(a_tru,
                                                           b_tru,
                                                           c_tru, [C],
                                                           symbol="P43212")[0]
                print(
                    "Rank %d, filename=%s, ang=%f, init_ang=%f, a=%f, init_a=%f, c=%f, init_c=%f"
                    % (rank, data_fname, angular_offset,
                       angular_offset_init, a_ref, a_init, c_ref, c_init))
            except Exception as err:
                print("Rank %d, filename=%s, error %s" %
                      (rank, data_fname, err))

            # free the memory from diffBragg instance
            RUC.S.D.free_all()
            del img  # not sure if needed here..
            del img_in_photons

            if args.testmode:
                exit()

            # peak at the memory usage of this rank
            mem = getrusage(RUSAGE_SELF).ru_maxrss  # peak mem usage in KB
            mem = mem / 1e6  # convert to GB

            if rank == 0:
                print("RANK 0: %.2g total pixels in %d/%d bboxes (file %d / %d); MemUsg=%2.2g GB"
                      % (Ntot, len(bboxes), n_bboxes_total, img_num + 1, len(my_shots), mem))

            # TODO: accumulate all pixels
            #self.all_bbox_pixels += data_boxes

        for h in my_open_files.values():
            h.close()

        print("Rank %d; all subimages loaded!" % rank)
# Example #5
# 0
def features_encoding (df, flag):
    '''Creates a HDF5 file containing
       all the features
       Rnx features are encoded in fingerprints'''
    no_of_rxns = 10
    fp_len = 4096
    rxn_len = fp_len + 2
    pathway_len = 3
    y_len = 1

    if flag == "train":
        sys_exit('Encoding feature for training data not available file data_train.h5 must be present in models folder')
    # elif flag == "predict":

    print("Encodining features for the Test set......")
    # temp_f_name = str(uuid4())
    f = h5py_File(
        NamedTemporaryFile(delete=True),
        'w'
    )
    number = rxn_len * no_of_rxns + pathway_len + y_len
    dset = f.create_dataset(
        'data',
        (0, number),
        dtype='i2',
        maxshape=(None, number),
        compression='gzip'
    )

    for row in tqdm(range(len(df))):
        pathway_rxns = np.array([]).reshape(0, rxn_len * no_of_rxns)
        rxns_list = []
        for rxn_no_ in range(no_of_rxns):
            
            rxn_smiles_index = rxn_no_ * 3
            rxn_dg_index = (rxn_no_ + 1)* 3 -2
            rxn_rule_score_index = (rxn_no_ + 1)* 3 - 1
        
            if  str(df.iloc[row , rxn_smiles_index]) != '0':
                #print(df.iloc[row , rxn_smiles_index])
                rxn_smiles = df.iloc[row , rxn_smiles_index]
                rxn_smiles_list = rxn_smiles.split(">>")
                #print(len(rxn_smiles_list))

                if len(rxn_smiles_list) == 2:

                    sub_smiles = rxn_smiles_list[0]
                    sub_m= Chem.MolFromSmiles(sub_smiles)
                    #print(m)
                    sub_fp = AllChem.GetMorganFingerprintAsBitVect(sub_m, 2, nBits = 2048)
                    sub_arr = np.array([])
                    DataStructs.ConvertToNumpyArray(sub_fp, sub_arr)
                    sub_fp= sub_arr.reshape(1,-1)

                    pro_smiles = rxn_smiles_list[1]
                    pro_m= Chem.MolFromSmiles(pro_smiles)
                    #print(m)
                    pro_fp = AllChem.GetMorganFingerprintAsBitVect(pro_m, 2, nBits = 2048)
                    pro_arr = np.zeros((1,))
                    DataStructs.ConvertToNumpyArray(pro_fp, pro_arr)
                    pro_fp= pro_arr.reshape(1,-1)
                    rxn_fp = np.concatenate([sub_fp , pro_fp]).reshape(1, -1)

                elif len(rxn_smiles_list) < 2:
                    
                    pro_smiles = rxn_smiles_list[0]
                    #print(pro_smiles)
                    pro_m= Chem.MolFromSmiles(pro_smiles)
                    #print(pro_m)
                    pro_fp = AllChem.GetMorganFingerprintAsBitVect(pro_m, 2, nBits = fp_len) # JLF: not good !!
                    pro_arr = np.zeros((1,))
                    DataStructs.ConvertToNumpyArray(pro_fp, pro_arr)
                    rxn_fp= pro_arr.reshape(1,-1)
                else:
                    print("There is a problem with the number of components in the reaction")

            else:
                rxn_fp = np.zeros(fp_len).reshape(1,-1)

            rxn_dg = df.iloc[row , rxn_dg_index].reshape(1,-1)
            rxn_rule_score = df.iloc[row , rxn_rule_score_index].reshape(1,-1)
            rxns_list.extend([rxn_fp, rxn_dg, rxn_rule_score])
            #print(rxn_rule_score)

        pathway_rxns = np.concatenate(rxns_list , axis = 1).reshape(1,-1)
        pathway_dg = df.loc[row, "Pathway_Delta_G"].reshape(1,-1)
        pathway_flux = df.loc[row, "Pathway_Flux"].reshape(1,-1)
        pathway_score = df.loc[row, "Pathway_Score"].reshape(1,-1)
        pathway_y = df.loc[row, "Round1_OR"].reshape(1,-1)
        feature = np.concatenate((pathway_rxns, pathway_dg, pathway_flux, pathway_score, pathway_y), axis =1)
        dset.resize(dset.shape[0]+feature.shape[0], axis=0)
        dset[-feature.shape[0]:]= feature
        #print(pathway_flux)

    return dset
# Example #6
# 0
def main():
    """Filter the bounding boxes of every shot and store per-shot keeper flags.

    Rank 0 counts the shots in every hdf5 file matched by ``args.glob`` and
    splits the (file index, shot index) pairs across ``size`` ranks; the split
    is broadcast when ``has_mpi`` is set. Each rank then loads its images with
    dxtbx and flags a bbox as a keeper when it passes every enabled filter:
    resolution ring (``args.reshigh`` < d < ``args.reslow``), SNR
    (``args.snrmin``), tilt rms (``args.tiltfilt``), tilt error
    (``args.tilterrmax``), indexed flag (unless ``args.notindexed``) and
    boundary (unless ``args.onboundary``). Rank 0 gathers the flags and writes
    them as boolean datasets named ``"<args.keeperstag><shot_idx>"`` into each
    file's ``bboxes`` group.

    Relies on module-level globals such as ``args``, ``rank``, ``size``,
    ``comm``, ``has_mpi``, ``MPI`` and ``h5py_File``.
    """
    # some parameters
    int_radius = 5  # dirty integration uses a disk of diameter 2*int_radius pixels
    gain = args.gain
    resmin = args.reshigh  # high res cutoff
    resmax = args.reslow  # low res cutoff
    fnames = glob(args.glob)

    # NOTE: for reference, inside each h5 file there is
    #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

    # get the total number of shots using worker 0
    if rank == 0:
        print("I am root. I am calculating total number of shots")
        Nshots_per_file = []
        for fname in fnames:
            # open one file at a time so an error cannot leak the other handles
            with h5py_File(fname, "r") as h:
                Nshots_per_file.append(h["h5_path"].shape[0])
        Nshots_tot = sum(Nshots_per_file)
        print("I am root. Total number of shots is %d" % Nshots_tot)

        print("I am root. I will divide shots amongst workers.")
        shot_tuples = []
        for i_f, n_shots in enumerate(Nshots_per_file):
            shot_tuples += [(i_f, i_shot) for i_shot in range(n_shots)]

        from numpy import array_split
        print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
        shots_for_rank = array_split(shot_tuples, size)
    else:
        shots_for_rank = None

    if has_mpi:
        shots_for_rank = comm.bcast(shots_for_rank, root=0)

    my_shots = shots_for_rank[rank]

    # open the unique filenames for this rank
    # TODO: check max allowed pointers to open hdf5 file
    my_unique_fids = {fidx for fidx, _ in my_shots}
    my_open_files = {fidx: h5py_File(fnames[fidx], "r") for fidx in my_unique_fids}

    Ntot = 0
    all_kept_bbox = []
    all_is_kept_flags = []
    try:
        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            h = my_open_files[fname_idx]
            shot_key = "shot%d" % shot_idx

            # load the dxtbx image data directly:
            # NOTE: h5_path is really the image file path
            img_path = h["h5_path"][shot_idx]
            if six.PY3:
                img_path = img_path.decode("utf-8")
            loader = dxtbx.load(img_path)
            raw_data = loader.get_raw_data()
            if isinstance(raw_data, tuple):
                img_data = array([p.as_numpy_array() for p in raw_data])
            else:
                img_data = raw_data.as_numpy_array()

            bboxes = h["bboxes"][shot_key][()]
            panel_ids = h["panel_ids"][shot_key][()]
            nspots = len(bboxes)

            # use the known cell to compute the resolution of the spots
            reso = h["resolution"][shot_key][()]
            in_reso_ring = array([resmin < d < resmax for d in reso])

            # Dirty integrater, sets integration region as disk of diameter 2*int_radius pixels
            if len(img_data.shape) == 2:  # single panel image
                assert len(set(panel_ids)) == 1  # sanity check
                img_data = [img_data]

            is_a_keeper = [in_reso_ring[i_spot] for i_spot in range(nspots)]

            hgroups = h.keys()

            if args.snrmin is not None:
                if "SNR_Leslie99" in hgroups:
                    SNR = h["SNR_Leslie99"][shot_key][()]
                else:
                    if rank == 0:
                        print("WARNING USING DIRTY SNR ESTIMATE!")
                    dirties = {
                        pid: Integrator(img_data[pid], int_radius=int_radius, gain=gain)
                        for pid in set(panel_ids)
                    }
                    int_data = [
                        dirties[pid].integrate_bbox_dirty(bb)
                        for pid, bb in zip(panel_ids, bboxes)
                    ]
                    # signal, background, variance  # these are from the paper Leslie '99
                    s, b, var = map(array, zip(*int_data))
                    SNR = s / sqrt(var)
                is_a_keeper = [
                    k and snr >= args.snrmin for k, snr in zip(is_a_keeper, SNR)
                ]

            if "tilt_rms" in hgroups and args.tiltfilt is not None:
                tilt_rms = h["tilt_rms"][shot_key][()]
                is_a_keeper = [
                    k and rms < args.tiltfilt for k, rms in zip(is_a_keeper, tilt_rms)
                ]

            if "tilt_error" in hgroups:
                if args.tilterrmax is not None:
                    tilt_err = h["tilt_error"][shot_key][()]
                    is_a_keeper = [
                        k and err <= args.tilterrmax
                        for k, err in zip(is_a_keeper, tilt_err)
                    ]
            elif rank == 0:
                print("WARNING: tilt_error not in hdf5 file")

            if "indexed_flag" in hgroups:
                #TODO change me to assume indexed_flag is a bool
                if not args.notindexed:
                    indexed_flag = h["indexed_flag"][shot_key][()]
                    is_a_keeper = [
                        k and (idx > 0) for k, idx in zip(is_a_keeper, indexed_flag)
                    ]
            elif rank == 0:
                print("WARNING: indexed_flag not in hdf5 file")

            if "is_on_boundary" in hgroups:
                if not args.onboundary:
                    on_boundary = h["is_on_boundary"][shot_key][()]
                    is_a_keeper = [
                        k and not onbound
                        for k, onbound in zip(is_a_keeper, on_boundary)
                    ]
            elif rank == 0:
                print("WARNING: is_on_boundary not in hdf5 file")

            if rank == 0:
                print("Keeping %d out of %d spots" % (sum(is_a_keeper), nspots))

            if rank == 0 and args.plot is not None:
                import numpy as np
                for pid in set(panel_ids):
                    plt.gcf().clear()
                    m = np.median(img_data[pid])
                    s = np.std(img_data[pid][img_data[pid] > 10])
                    vmin = m - s
                    vmax = m + 5 * s
                    plt.imshow(img_data[pid], vmax=vmax, vmin=vmin, cmap='viridis')
                    for i_spot in range(nspots):
                        # only outline kept spots that live on this panel
                        if not is_a_keeper[i_spot] or not panel_ids[i_spot] == pid:
                            continue
                        x1, x2, y1, y2 = bboxes[i_spot]
                        patch = plt.Rectangle(xy=(x1, y1),
                                              width=x2 - x1,
                                              height=y2 - y1,
                                              fc='none',
                                              ec='r')
                        plt.gca().add_patch(patch)
                    plt.title("image %s\nrank%d , index %d, Panel=%d" %
                              (img_path, rank, img_num, pid))
                    if args.plot == -1:
                        plt.show()
                    else:
                        plt.draw()
                        plt.pause(args.plot)

            kept_bboxes = [
                bboxes[i_bb] for i_bb in range(len(bboxes)) if is_a_keeper[i_bb]
            ]

            Ntot += sum((j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in kept_bboxes)
            if rank == 0:
                print("%g total pixels (file %d / %d)" %
                      (Ntot, img_num + 1, len(my_shots)))
            all_kept_bbox.extend(list(bb) for bb in kept_bboxes)
            # store this information, write to disk
            all_is_kept_flags.append((fname_idx, shot_idx, is_a_keeper))
    finally:
        # close the open hdf5 files so we can write to them again
        for h in my_open_files.values():
            h.close()

    print("END OF LOOP")
    print("Rank %d; total bboxes=%d; Total pixels=%g" %
          (rank, len(all_kept_bbox), Ntot))
    if has_mpi:
        all_kept_bbox = MPI.COMM_WORLD.gather(all_kept_bbox, root=0)
        all_is_kept_flags = MPI.COMM_WORLD.gather(all_is_kept_flags, root=0)
    else:
        # single process: reproduce gather's list-of-per-rank-results layout
        all_kept_bbox = [all_kept_bbox]
        all_is_kept_flags = [all_is_kept_flags]

    if rank == 0:
        all_kept_bbox = [
            bbox for bbox_lst in all_kept_bbox for bbox in bbox_lst
        ]
        Ntot_pix = sum((j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in all_kept_bbox)
        print()
        print("<><><><><><><<><><><><><><><><><><><><><><>")
        print("I am root. total bboxes=%d, Total pixels=%g" %
              (len(all_kept_bbox), Ntot_pix))
        print("<><><><><><><<><><><><><><><><><><><><><><>")
        print()

        print("I am root. I will store flags for each bbox on each shot")

        all_flag_info = [i for sl in all_is_kept_flags for i in sl]  # flatten

        # open the hdf5 files in read+write mode and store the bbox keeper flags
        h5s = {i_f: h5py_File(f, "r+") for i_f, f in enumerate(fnames)}
        try:
            for i_info, (fidx, shot_idx, keeper_flags) in enumerate(all_flag_info):
                bbox_grp = h5s[fidx]["bboxes"]

                flag_name = "%s%d" % (args.keeperstag, shot_idx)

                # overwrite flags left by a previous run
                if flag_name in bbox_grp:
                    del bbox_grp[flag_name]

                bbox_grp.create_dataset(flag_name,
                                        data=keeper_flags,
                                        dtype=bool,
                                        compression='lzf')

                if i_info % 5 == 0:
                    print("I am root. I saved bbox selection flags ( %d / %d ) " %
                          (i_info + 1, len(all_flag_info)))
        finally:
            # close the open files..
            for h in h5s.values():
                h.close()
Пример #7
0
    def load(self):
        """Distribute shots across ranks and load this rank's refinement inputs.

        On rank 0:
        - count the shots in every file of ``self.fnames`` and the number of
          kept ROIs per shot;
        - optionally shuffle the shot order to balance ROIs across ranks
          (``args.partition``);
        - take sub-group ``args.groupId`` of ``args.ngroups`` and split it
          across ``size`` ranks. A mapping stored next to
          ``args.restartfile`` / ``args.xinitfile`` replaces that split when
          given. The mapping is saved to ``args.outdir`` when set.

        The mapping is broadcast when ``has_mpi`` is set. Every rank then loads
        its shots and fills the ``self.all_*`` dicts keyed by local image
        number. Per shot this covers:
        - ROI pixels of the kept bboxes;
        - tilt-plane inits and Miller indices;
        - spectrum;
        - indexed and ground-truth crystal models.

        Relies on module-level globals such as ``args``, ``rank``, ``size``,
        ``comm``, ``has_mpi`` and ``h5py_File``.
        """
        import os
        from numpy import array_split
        from numpy.random import permutation

        def _save_shot_mapping(shots):
            """Write the per-rank shot chunks to ``<args.outdir>/shots_for_rank.npy``."""
            # array_split chunks may differ in length, which np.save cannot pack into a
            # regular array, so store a 1-D object array (np.save pickles it)
            packed = np.empty(len(shots), dtype=object)
            for i_rank, chunk in enumerate(shots):
                packed[i_rank] = chunk
            np.save(os.path.join(args.outdir, "shots_for_rank"), packed)

        def _load_shot_mapping(dirname):
            """Load the shot mapping stored in ``dirname`` by a previous run."""
            print("Loading shot mapping from dir %s" % dirname)
            # NOTE: allow_pickle is required for the object array written by
            # _save_shot_mapping; only load mapping files produced by this program
            return np.load(os.path.join(dirname, "shots_for_rank.npy"), allow_pickle=True)

        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']
        # get the total number of shots using worker 0
        if rank == 0:
            self.time_load_start = time.time()
            print("I am root. I am calculating total number of shots")
            h5s = [h5py_File(f, "r") for f in self.fnames]
            try:
                Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
                Nshots_tot = sum(Nshots_per_file)
                print("I am root. Total number of shots is %d" % Nshots_tot)

                print("I am root. I will divide shots amongst workers.")
                shot_tuples = []
                roi_per = []
                for i_f, n_shots in enumerate(Nshots_per_file):
                    shot_tuples += [(i_f, i_shot) for i_shot in range(n_shots)]
                    # store the number of usable roi per shot in order to divide shots amongst ranks equally
                    bbox_grp = h5s[i_f]["bboxes"]
                    roi_per += [sum(bbox_grp["%s%d" % (args.keeperstag, i_shot)][()])
                                for i_shot in range(n_shots)]
            finally:
                # close the open h5s..
                for h in h5s:
                    h.close()

            print("I am root. Number of uniques = %d" % len(set(shot_tuples)))

            # divide the array into chunks of roughly equal sum (total number of ROI)
            if args.partition and args.restartfile is None and args.xinitfile is None:
                diff = np.inf
                roi_per = np.array(roi_per)
                tstart = time.time()
                best_order = range(len(roi_per))
                print("Partitioning for better load balancing across ranks.. ")
                while True:
                    # try a random shot order; keep it if per-rank ROI totals are more even
                    order = permutation(len(roi_per))
                    res = [sum(a) for a in np.array_split(roi_per[order], size)]
                    new_diff = max(res) - min(res)
                    t_elapsed = time.time() - tstart
                    t_remain = args.partitiontime - t_elapsed
                    if new_diff < diff:
                        diff = new_diff
                        best_order = order.copy()
                        print("Best diff=%d, Parition time remaining: %.3f seconds" % (diff, t_remain))
                    if t_elapsed > args.partitiontime:
                        break
                shot_tuples = [shot_tuples[i] for i in best_order]

            elif args.partition and args.restartfile is not None:
                print("Warning: skipping partitioning time to use shot mapping as laid out in restart file dir")
            else:
                print("Proceeding without partitioning")

            # optional to divide into a sub group
            shot_tuples = array_split(shot_tuples, args.ngroups)[args.groupId]
            shots_for_rank = array_split(shot_tuples, size)
            if args.outdir is not None:  # save for a fast restart (shot order is important!)
                _save_shot_mapping(shots_for_rank)
            # the directory containing a restart / xinit file should hold a shots-for-rank file;
            # restart is read first, then xinit (so xinit wins when both are given),
            # and the mapping is propagated to outdir
            for init_fpath in (args.restartfile, args.xinitfile):
                if init_fpath is None:
                    continue
                shots_for_rank = _load_shot_mapping(os.path.dirname(init_fpath))
                if args.outdir is not None:
                    _save_shot_mapping(shots_for_rank)
        else:
            shots_for_rank = None

        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            start = 0 if args.loadstart is None else args.loadstart
            my_shots = my_shots[start: start + self.Nload]
        if len(my_shots) > 0:
            print("Rank %d: I will load %d shots, first shot: %s, last shot: %s"
                  % (rank, len(my_shots), my_shots[0], my_shots[-1]))
        else:
            print("Rank %d: I will load 0 shots" % rank)

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        my_unique_fids = {fidx for fidx, _ in my_shots}
        self.my_open_files = {fidx: h5py_File(self.fnames[fidx], "r") for fidx in my_unique_fids}

        try:
            for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
                h = self.my_open_files[fname_idx]
                shot_key = "shot%d" % shot_idx

                # path of this shot's image file; h5py can return it as bytes
                npz_path = h["h5_path"][shot_idx]
                if isinstance(npz_path, bytes):
                    npz_path = npz_path.decode("utf-8")

                if args.imgdirname is not None:
                    npz_path = os.path.join(args.imgdirname, npz_path.split("/kaladin/")[1])

                if args.noiseless:
                    img_handle = numpy_load(npz_path.replace(".npz", ".noiseless.npz"))
                else:
                    img_handle = numpy_load(npz_path)

                img = img_handle["img"]
                if len(img.shape) == 2:  # if single panel
                    img = array([img])

                B = beam_from_dict(img_handle["beam"][()])

                log_init_crystal_scale = 0  # default
                if args.usepreoptscale:
                    log_init_crystal_scale = h["crystal_scale_%s" % args.preopttag][shot_idx]
                # get the indexed crystal Amatrix
                if args.usepreoptAmat:
                    Amat = h["Amatrices_%s" % args.preopttag][shot_idx]
                else:
                    Amat = h["Amatrices"][shot_idx]
                amat_elems = list(sqr(Amat).inverse().elems)
                # real space basis vectors:
                a_real, b_real, c_real = amat_elems[:3], amat_elems[3:6], amat_elems[6:]

                # dxtbx indexed crystal model
                C = Crystal(a_real, b_real, c_real, "P43212")

                # change basis here ? Or maybe just average a/b
                a, b, c, _, _, _ = C.get_unit_cell().parameters()
                a_init = .5 * (a + b)
                c_init = c

                # shoe boxes where we expect spots
                bbox_all = h["bboxes"][shot_key][()]
                n_bboxes_total = bbox_all.shape[0]
                # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
                is_a_keeper = h["bboxes"]["%s%d" % (args.keeperstag, shot_idx)][()]
                # tilt plane to the background pixels in the shoe boxes
                tilt_abc_all = h["tilt_abc"][shot_key][()]
                # miller indices (not yet reduced by symm equivs)
                Hi_all = h["Hi"][shot_key][()]
                try:
                    panel_ids_dset = h["panel_ids"][shot_key]
                    has_panels = True
                except KeyError:
                    has_panels = False

                # apply the filters (each dataset is read once, then indexed in memory):
                proc_file_idx = [i_bb for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]
                bboxes = [bbox_all[i_bb] for i_bb in proc_file_idx]
                tilt_abc = [tilt_abc_all[i_bb] for i_bb in proc_file_idx]
                Hi = [tuple(Hi_all[i_bb]) for i_bb in proc_file_idx]

                if has_panels:
                    panel_ids_all = panel_ids_dset[()]
                    panel_ids = [panel_ids_all[i_bb] for i_bb in proc_file_idx]
                else:
                    panel_ids = [0] * len(tilt_abc)

                # how many pixels do we have
                tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]

                # load some ground truth data from the simulation dumps (e.g. spectrum)
                h5_fname = npz_path.replace(".npz", "")
                if args.character is not None:
                    h5_fname = h5_fname.replace("rock", args.character)
                if args.testmode2:
                    h5_fname = npz_path.split(".npz")[0]
                # read everything needed up front so the dump file is always closed
                with h5py_File(h5_fname, "r") as data:
                    xtal_scale_truth = data["spot_scale"][()]
                    tru = sqr(data["crystalA"][()]).inverse().elems
                    fluxes = data["spectrum"][()]
                    es = data["exposure_s"][()]
                    # TODO: wavelens should come from the imageset file itself
                    if "wavelengths" not in data.keys():
                        raise KeyError("Wavelengths missing from hdf5 data")
                    wavelens = data["wavelengths"][()]

                a_tru, b_tru, c_tru = tru[:3], tru[3:6], tru[6:]
                C_tru = Crystal(a_tru, b_tru, c_tru, "P43212")

                fluxes *= es  # multiply by the exposure time
                # dont simulate when there are no photons!
                spectrum = [(wave, flux) for wave, flux in zip(wavelens, fluxes)
                            if flux > self.flux_min]

                if args.forcemono:
                    spectrum = [(B.get_wavelength(), sum(fluxes))]

                # make a unit cell manager that the refiner will use to track the B-matrix
                ucell_man = TetragonalManager(a=a_init, c=c_init)

                if img_num == 0:  # only initialize the simulator after loading the first image
                    # the structure factor files do not depend on the shot: read them once
                    self.Fhkl_obs = open_flex(args.Fobs).as_amplitude_array()
                    self.Fhkl_ref = None
                    if args.Fref is not None:
                        # this reference miller array is used to track CC and R-factor
                        self.Fhkl_ref = open_flex(args.Fref).as_amplitude_array()
                    self.initialize_simulator(C, B, spectrum, self.Fhkl_obs)

                # map the miller array to ASU
                Hi_asu = map_hkl_list(Hi, self.anomalous_flag, self.symbol)

                # copy the image as photons
                img_in_photons = (img / args.gainval).astype('float32')

                # Here, takeout from the image only whats necessary to perform refinement
                # first filter the spot rois so they dont occur exactly at the boundary of the image (inclusive range in nB)
                assert len(img_in_photons.shape) == 3  # sanity
                nslow, nfast = img_in_photons[0].shape
                bboxes = array(bboxes)
                for i_bbox, (_, x2, _, y2) in enumerate(bboxes):
                    if x2 == nfast:
                        bboxes[i_bbox][1] = x2 - 1  # update roi_xmax
                    if y2 == nslow:
                        bboxes[i_bbox][3] = y2 - 1  # update roi_ymax

                # now cache the roi in nanoBragg format ((x1,x2), (y1,y2))
                # and also cache the pixels and the coordinates
                nanoBragg_rois = []  # special nanoBragg format
                xrel, yrel, roi_img = [], [], []
                for i_roi, (x1, x2, y1, y2) in enumerate(bboxes):
                    nanoBragg_rois.append(((x1, x2), (y1, y2)))
                    yr, xr = np_indices((y2 - y1 + 1, x2 - x1 + 1))
                    xrel.append(xr)
                    yrel.append(yr)
                    pid = panel_ids[i_roi]
                    # copy the slice: a view would keep the whole image alive
                    roi_img.append(img_in_photons[pid, y1:y2 + 1, x1:x2 + 1].copy())

                # drop the full images; only the copied ROIs are kept
                del img, img_in_photons

                self.all_pix += sum(tot_pix)

                # accumulate per-shot information
                self.global_image_id[img_num] = None  # TODO
                self.all_spot_roi[img_num] = bboxes
                self.all_abc_inits[img_num] = tilt_abc
                self.all_panel_ids[img_num] = panel_ids
                self.all_ucell_mans[img_num] = ucell_man
                self.all_spectra[img_num] = spectrum
                self.all_crystal_models[img_num] = C
                # these should be the log of the initial crystal scale
                self.log_of_init_crystal_scales[img_num] = log_init_crystal_scale
                self.all_crystal_scales[img_num] = xtal_scale_truth
                self.all_crystal_GT[img_num] = C_tru
                self.all_xrel[img_num] = xrel
                self.all_yrel[img_num] = yrel
                self.all_nanoBragg_rois[img_num] = nanoBragg_rois
                self.all_roi_imgs[img_num] = roi_img
                self.all_fnames[img_num] = npz_path
                self.all_proc_fnames[img_num] = h.filename
                self.all_Hi[img_num] = Hi
                self.all_Hi_asu[img_num] = Hi_asu
                self.all_proc_idx[img_num] = proc_file_idx
                self.all_shot_idx[img_num] = shot_idx  # this is the index of the shot in the process*h5 file
                # NOTE all originZ for each panel are the same in the simulated data.. Not necessarily true for real data
                shot_originZ = self.SIM.detector[0].get_origin()[2]
                self.shot_originZ_init[img_num] = shot_originZ
                print(img_num)
        finally:
            for h in self.my_open_files.values():
                h.close()
Пример #8
0
def r3_dnn_apply_keras(target_dirname,
                       old_stft_obj=None,
                       cuda=False,
                       saving_to_disk=True):
    """Denoise an STFT with the per-frequency Keras networks (step r3).

    The old STFT is taken from ``old_stft_obj`` or, when it is None, read from
    ``<target_dirname>/old_stft.mat``. Real and imaginary parts are
    concatenated along axis 1 and fed through ``process_each_frequency_keras``
    for every frequency index in ``k_mask``. All other frequencies are then
    zeroed. The upper half of the spectrum is rebuilt by conjugate symmetry.

    Args:
        target_dirname: target directory; the model directory is the one three
            ``dirname`` levels above it.
        old_stft_obj: optional mapping with 'old_stft_real' / 'old_stft_imag'
            arrays. The concatenated array is unpacked as
            (N_beams, N_elements_2, N_segments, N_fft); the exact input layout
            is an assumption to confirm against the caller.
        cuda: not used by this function; kept for interface compatibility.
        saving_to_disk: when true, also write the result to
            ``<target_dirname>/new_stft.mat`` via ``savemat``.

    Returns:
        dict with 'new_stft_real' and 'new_stft_imag': the real / imaginary
        halves of the processed (N_beams, N_elements_2, N_segments, N_fft)
        array, each transposed (all axes reversed).
    """
    LOGGER.info(
        '{}: r3: Denoising original stft with neural network model...'.format(
            target_dirname))
    scan_battery_dirname = os_path_dirname(target_dirname)
    model_dirname = os_path_dirname(os_path_dirname(scan_battery_dirname))

    # load stft data, combining stft_real and stft_imag along axis 1
    if old_stft_obj is None:
        old_stft_fpath = os_path_join(target_dirname, 'old_stft.mat')
        with h5py_File(old_stft_fpath, 'r') as f:
            stft = np_concatenate(
                [f['old_stft_real'][:], f['old_stft_imag'][:]], axis=1)
    else:
        stft = np_concatenate(
            [old_stft_obj['old_stft_real'], old_stft_obj['old_stft_imag']],
            axis=1)

    N_beams, N_elements_2, N_segments, N_fft = stft.shape
    N_elements = N_elements_2 // 2

    # move element position axis: (beams, elements, segments, fft) -> (beams, segments, elements, fft)
    stft = np_moveaxis(stft, 1, 2)

    # flatten the first two axes
    stft = np_reshape(stft, [N_beams * N_segments, N_elements_2, N_fft])

    # process stft with networks; the return value is ignored, so the helper
    # presumably updates stft in place — confirm in process_each_frequency_keras
    k_mask = list(range(3, 6))
    for frequency in k_mask:
        process_each_frequency_keras(model_dirname, stft, frequency)

    # restore the (beams, segments, elements, fft) layout
    stft = np_reshape(stft, [N_beams, N_segments, N_elements_2, N_fft])

    # set zero outside analysis frequency range (index list instead of a full-size mask)
    discard = [k for k in range(N_fft) if k not in k_mask]
    stft[:, :, :, discard] = 0

    # mirror data to negative frequencies using conjugate symmetry:
    # bin N_fft - k takes bin k, with the imaginary half negated.
    # Slicing to N_fft - end_index keeps source and target the same length
    # for both even and odd N_fft (for even N_fft it equals end_index).
    end_index = N_fft // 2
    stft[:, :, :, end_index + 1:] = np_flip(
        stft[:, :, :, 1:N_fft - end_index], axis=3)
    stft[:, :, N_elements:2 * N_elements, end_index + 1:] *= -1

    # move element position axis back: (beams, elements, segments, fft)
    stft = np_moveaxis(stft, 1, 2)

    # split into real / imaginary halves; .transpose() reverses all axes
    new_stft_real = stft[:, :N_elements, :, :].transpose()
    new_stft_imag = stft[:, N_elements:, :, :].transpose()
    del stft

    # save new stft data
    new_stft_obj = {
        'new_stft_real': new_stft_real,
        'new_stft_imag': new_stft_imag
    }
    if saving_to_disk:
        new_stft_fname = os_path_join(target_dirname, 'new_stft.mat')
        savemat(new_stft_fname, new_stft_obj)
    LOGGER.info('{}: r3 Done.'.format(target_dirname))
    return new_stft_obj