Example #1
def load_torch_zip(directory, filename, varnames, cuda=False, numpy=False):
    filename = os.path.join(directory, filename)
    loaded = numpy_load(filename)
    result = []
    for k in varnames:
        if k in loaded:
            val = loaded[k]
            if len(val.shape) == 0:
                # Load scalars as scalars
                result.append(val[()])
            else:
                result.append(torch.from_numpy(val))
        elif ('%s indices' % k) in loaded:
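            # reassemble a sparse tensor stored as separate '<name> indices' / '<name> values' / '<name> shape' arrays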
            values = torch.from_numpy(loaded['%s values' % k])
            indices = torch.from_numpy(loaded['%s indices' % k])
            shape = loaded['%s shape' % k]
            i_type, SparseTensor = sparse_types_for_dense_type(type(values))
            result.append(SparseTensor(indices, values, torch.Size(shape)))
        else:
            result.append(None)
    if cuda:
        result = [r.cuda() if hasattr(r, 'cuda') else r for r in result]
    elif numpy:
        result = [r.numpy() if hasattr(r, 'numpy') else r for r in result]
    return result
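
A minimal usage sketch for the loader above; the directory, archive name and variable names are illustrative only, and the archive is assumed to have been written with numpy.savez-style keys:

# hedged usage sketch: "checkpoint.npz" and the key names are illustrative
weights, step = load_torch_zip("checkpoints", "checkpoint.npz", ["weights", "step"])
if weights is None:
    print("key 'weights' not found in the archive")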
Example #2
    def __getitem__(self, indices):
        '''x.__getitem__(indices) <==> x[indices]

        Returns a numpy array.

        '''
        array = numpy_load(self._partition_file)

        indices = parse_indices(array.shape, indices)

        array = get_subspace(array, indices)

        if self._get_component('_masked_as_record'):
            # Convert a record array to a masked array
            array = numpy_ma_array(array['_data'], mask=array['_mask'],
                                   copy=False)
            array.shrink_mask()

        # Return the numpy array
        return array
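
A standalone sketch of the record-array to masked-array conversion used above, assuming a structured array with '_data' and '_mask' fields; the values are illustrative only:

# hedged sketch: a structured array standing in for the stored record array
import numpy

rec = numpy.zeros(3, dtype=[('_data', 'f8'), ('_mask', 'b1')])
rec['_data'] = [1.0, 2.0, 3.0]
rec['_mask'] = [False, True, False]
masked = numpy.ma.array(rec['_data'], mask=rec['_mask'], copy=False)
masked.shrink_mask()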
Example #3
def main():
    # some parameters
    int_radius = 5
    gain = args.gain
    # data is stored in 39 HDF5 files
    resmin = args.reshigh  # high res cutoff
    resmax = args.reslow  # low res cutoff
    fnames = glob(args.glob)
    if rank == 0:
        print("CMD looked like:")
        print(" ".join(sys.argv))

    # NOTE: for reference, inside each h5 file there is
    #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

    # get the total number of shots using worker 0
    if rank == 0:
        print("I am root. I am calculating total number of shots")
        h5s = [h5py_File(f, "r") for f in fnames]
        Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
        Nshots_tot = sum(Nshots_per_file)
        print("I am root. Total number of shots is %d" % Nshots_tot)

        print("I am root. I will divide shots amongst workers.")
        shot_tuples = []
        for i_f, fname in enumerate(fnames):
            fidx_shotidx = [(i_f, i_shot)
                            for i_shot in range(Nshots_per_file[i_f])]
            shot_tuples += fidx_shotidx

        from numpy import array_split
        print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
        shots_for_rank = array_split(shot_tuples, size)

        # close the open h5s..
        for h in h5s:
            h.close()

    else:
        Nshots_tot = None
        shots_for_rank = None
        h5s = None

    # Nshots_tot = comm.bcast( Nshots_tot, root=0)
    if has_mpi:
        shots_for_rank = comm.bcast(shots_for_rank, root=0)
    # h5s = comm.bcast( h5s, root=0)  # pull in the open hdf5 files

    my_shots = shots_for_rank[rank]

    # open the unique filenames for this rank
    # TODO: check max allowed pointers to open hdf5 file
    my_unique_fids = set([fidx for fidx, _ in my_shots])
    my_open_files = {
        fidx: h5py_File(fnames[fidx], "r")
        for fidx in my_unique_fids
    }

    all_num_kept = 0
    all_num_below = 0
    Ntot = 0
    all_kept_bbox = []
    all_is_kept_flags = []
    for img_num, (fname_idx, shot_idx) in enumerate(my_shots):

        h = my_open_files[fname_idx]

        # load the dxtbx image data directly:
        img_path = h["h5_path"][shot_idx]
        try:
            img_path = img_path.decode()
        except AttributeError:
            pass
        img_data = numpy_load(img_path)["img"]
        bboxes = h["bboxes"]["shot%d" % shot_idx][()]
        panel_ids = h["panel_ids"]["shot%d" % shot_idx][()]
        hi, ki, li = h["Hi"]["shot%d" % shot_idx][()].T
        nspots = len(bboxes)

        if "below_zero" in list(h.keys()):
            # tilt dips below zero
            below_zero = h["below_zero"]["shot%d" % shot_idx][()]
        else:
            below_zero = None

        # use the known cell to compute the resolution of the spots
        # FIXME: get the hard-coded unit cell out of here!
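        # (tetragonal cell: 1/d^2 = (h^2 + k^2)/a^2 + l^2/c^2, with the hard-coded a and c below)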

        if args.resoinfile:
            reso = h["resolution"]["shot%d" % shot_idx][()]
        else:
            if args.p9:
                reso = 1 / sqrt((hi**2 + ki**2) / 114 / 114 +
                                li**2 / 32.5 / 32.5)
            else:
                # TODO: why does 0,0,0 ever appear as a reflection ? Should never happen...
                reso = 1 / sqrt((hi**2 + ki**2) / 79.1 / 79.1 +
                                li**2 / 38.4 / 38.4)

        in_reso_ring = array([resmin <= d < resmax for d in reso])

        # Dirty integrator: sets the integration region as a disk of diameter 2*int_radius pixels
        if len(img_data.shape) == 2:  # single panel image
            assert len(set(panel_ids)) == 1  # sanity check
            img_data = [img_data]

        is_a_keeper = [in_reso_ring[i_spot] for i_spot in range(nspots)]

        print("In reso: %d" % sum(is_a_keeper))

        hgroups = h.keys()

        if args.snrmin is not None:
            if "SNR_Leslie99" in hgroups:
                SNR = h["SNR_Leslie99"]["shot%d" % shot_idx][()]
            else:
                if rank == 0:
                    print("WARNING USING DIRTY SNR ESTIMATE!")
                dirties = {
                    pid: Integrator(img_data[pid],
                                    int_radius=int_radius,
                                    gain=gain)
                    for pid in set(panel_ids)
                }

                int_data = [
                    dirties[pid].integrate_bbox_dirty(bb)
                    for pid, bb in zip(panel_ids, bboxes)
                ]

                # signal, background, variance  # these are from the paper Leslie '99
                s, b, var = map(array, zip(*int_data))
                SNR = s / sqrt(var)
            is_a_keeper = [
                k and snr >= args.snrmin for k, snr in zip(is_a_keeper, SNR)
            ]

        print("In reso and SNR: %d" % sum(is_a_keeper))

        if "tilt_rms" in hgroups:
            if args.tiltfilt is not None:
                tilt_rms = h["tilt_rms"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and rms < args.tiltfilt
                    for k, rms in zip(is_a_keeper, tilt_rms)
                ]

        if "tilt_error" in hgroups:
            if args.tilterrmax is not None:
                tilt_err = h["tilt_error"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and err <= args.tilterrmax
                    for k, err in zip(is_a_keeper, tilt_err)
                ]
            elif args.tilterrperc is not None:
                assert 0 < args.tilterrperc < 100
                tilt_err = h["tilt_error"]["shot%d" % shot_idx][()]
                val = percentile(tilt_err, args.tilterrperc)
                is_a_keeper = [
                    k and err < val for k, err in zip(is_a_keeper, tilt_err)
                ]
        else:
            if rank == 0:
                print("WARNING: tilt_error not in hdf5 file")

        if "indexed_flag" in hgroups:
            #TODO change me to assume indexed_flag is a bool
            if not args.notindexed:
                indexed_flag = h["indexed_flag"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and (idx > 0)
                    for k, idx in zip(is_a_keeper, indexed_flag)
                ]
        else:
            if rank == 0:
                print("WARNING: indexed_flag not in hdf5 file")

        if "is_on_boundary" in hgroups:
            if not args.onboundary:
                on_boundary = h["is_on_boundary"]["shot%d" % shot_idx][()]
                is_a_keeper = [
                    k and not onbound
                    for k, onbound in zip(is_a_keeper, on_boundary)
                ]
        else:
            if rank == 0:
                print("WARNING: is_on_boundary not in hdf5 file")

        if args.keeperstride is not None:
            from numpy import where
            where_kept = where(is_a_keeper)[0]
            keeper_pos = where_kept[::args.keeperstride]
            for w in where_kept:
                if w not in keeper_pos:
                    is_a_keeper[w] = False

        nkept = np_sum(is_a_keeper)
        if rank == 0:
            print("Keeping %d out of %d spots" % (nkept, nspots))
        if below_zero is not None:
            num_below = np_sum(below_zero[is_a_keeper])
            all_num_below += num_below
            all_num_kept += nkept

        if rank == 0 and args.plot is not None:
            for pid in set(panel_ids):
                plt.gcf().clear()
                plt.imshow(img_data[pid], vmax=250, cmap='viridis')
                for i_spot in range(nspots):
                    if not is_a_keeper[i_spot]:
                        continue
                    if not panel_ids[i_spot] == pid:
                        continue
                    x1, x2, y1, y2 = bboxes[i_spot]
                    patch = plt.Rectangle(xy=(x1, y1),
                                          width=x2 - x1,
                                          height=y2 - y1,
                                          fc='none',
                                          ec='r')
                    plt.gca().add_patch(patch)
                plt.title("Panel=%d" % pid)
                if args.plot == -1:
                    plt.show()
                else:
                    plt.draw()
                    plt.pause(args.plot)

        kept_bboxes = [
            bboxes[i_bb] for i_bb in range(len(bboxes)) if is_a_keeper[i_bb]
        ]

        tot_pix = [(j2 - j1) * (i2 - i1)
                   for i1, i2, j1, j2 in kept_bboxes]
        Ntot += sum(tot_pix)
        if rank == 0:
            print("%g total pixels (file %d / %d)" %
                  (Ntot, img_num + 1, len(my_shots)))
        all_kept_bbox += map(list, kept_bboxes)
        # store this information, to be written to disk later
        all_is_kept_flags += [(fname_idx, shot_idx, is_a_keeper)]

    # close the open hdf5 files so we can write to them again
    for h in my_open_files.values():
        h.close()

    print("END OF LOOP")
    print("Rank %d; total bboxes=%d; Total pixels=%g" %
          (rank, len(all_kept_bbox), Ntot))
    all_kept_bbox = MPI.COMM_WORLD.gather(all_kept_bbox, root=0)
    all_is_kept_flags = MPI.COMM_WORLD.gather(all_is_kept_flags, root=0)

    all_kept = comm.reduce(all_num_kept)
    all_below = comm.reduce(all_num_below)
    if rank == 0:
        print("TOTAL BBOXES KEPT = %d; TOTAL BBOXES BELOW ZERO=%d" %
              (all_kept, all_below))
    if rank == 0:
        all_kept_bbox = [
            bbox for bbox_lst in all_kept_bbox for bbox in bbox_lst
        ]
        Ntot_pix = sum([(j2 - j1) * (i2 - i1)
                        for i1, i2, j1, j2 in all_kept_bbox])
        print("\n<><><><><><><<><><><><><><><><><><><><><><>")
        print("I am root. total bboxes=%d, Total pixels=%g" %
              (len(all_kept_bbox), Ntot_pix))
        print("<><><><><><><<><><><><><><><><><><><><><><>\n")

        print("I am root. I will store flags for each bbox on each shot")

        all_flag_info = [i for sl in all_is_kept_flags for i in sl]  # flatten

        # open the hdf5 files in read+write mode and store the bbox keeper flags
        h5s = {i_f: h5py_File(f, "r+") for i_f, f in enumerate(fnames)}

        for i_info, (fidx, shot_idx, keeper_flags) in enumerate(all_flag_info):
            bbox_grp = h5s[fidx]["bboxes"]

            flag_name = "%s%d" % (args.keeperstag, shot_idx)

            if flag_name in bbox_grp:
                del bbox_grp[flag_name]

            bbox_grp.create_dataset(flag_name,
                                    data=keeper_flags,
                                    dtype=bool,
                                    compression='lzf')

            if i_info % 5 == 0:
                print("I am root. I saved bbox selection flags ( %d / %d ) " %
                      (i_info + 1, len(all_flag_info)))

        # close the open files..
        for h in h5s.values():
            h.close()
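
The shot-distribution pattern above can be reduced to a short sketch: rank 0 enumerates (file, shot) tuples, splits them with numpy's array_split and broadcasts the split, and every rank then indexes the list with its own rank. This assumes mpi4py is available; the shot counts are illustrative.

# hedged sketch of the rank-0 work distribution used above (mpi4py assumed)
from mpi4py import MPI
from numpy import array_split

comm = MPI.COMM_WORLD
rank, size = comm.rank, comm.size

if rank == 0:
    # rank 0 builds the full list of (file index, shot index) tuples
    shots_per_file = [10, 7, 12]  # illustrative
    shot_tuples = [(i_f, i_s) for i_f, n in enumerate(shots_per_file) for i_s in range(n)]
    shots_for_rank = array_split(shot_tuples, size)
else:
    shots_for_rank = None

shots_for_rank = comm.bcast(shots_for_rank, root=0)
my_shots = shots_for_rank[rank]
print("rank %d handles %d shots" % (rank, len(my_shots)))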
Example #4
    def load(self):

        # some parameters

        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']

        # get the total number of shots using worker 0
        if rank == 0:
            print("I am root. I am calculating total number of shots")
            h5s = [h5py_File(f, "r") for f in self.fnames]
            Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
            Nshots_tot = sum(Nshots_per_file)
            print("I am root. Total number of shots is %d" % Nshots_tot)

            print("I am root. I will divide shots amongst workers.")
            shot_tuples = []
            for i_f, fname in enumerate(self.fnames):
                fidx_shotidx = [(i_f, i_shot)
                                for i_shot in range(Nshots_per_file[i_f])]
                shot_tuples += fidx_shotidx

            from numpy import array_split
            print("I am root. Number of uniques = %d" % len(set(shot_tuples)))
            shots_for_rank = array_split(shot_tuples, size)

            # close the open h5s..
            for h in h5s:
                h.close()

        else:
            Nshots_tot = None
            shots_for_rank = None
            h5s = None

        #Nshots_tot = comm.bcast( Nshots_tot, root=0)
        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)
        #h5s = comm.bcast( h5s, root=0)  # pull in the open hdf5 files

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            my_shots = my_shots[:self.Nload]

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        my_unique_fids = set([fidx for fidx, _ in my_shots])
        my_open_files = {
            fidx: h5py_File(self.fnames[fidx], "r")
            for fidx in my_unique_fids
        }
        Ntot = 0
        self.all_bbox_pixels = []
        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            if img_num == args.Nmax:
                #print("Already processed maximum number images!")
                continue
            h = my_open_files[fname_idx]

            # load the dxtbx image data directly:
            npz_path = h["h5_path"][shot_idx]
            # NOTE take me out!
            if args.testmode:
                import os
                npz_path = os.path.basename(npz_path)
            img_handle = numpy_load(npz_path)
            img = img_handle["img"]

            if len(img.shape) == 2:  # if single panel
                img = np.array([img])

            #D = det_from_dict(img_handle["det"][()])
            B = beam_from_dict(img_handle["beam"][()])

            # get the indexed crystal Amatrix
            Amat = h["Amatrices"][shot_idx]
            amat_elems = list(sqr(Amat).inverse().elems)
            # real space basis vectors:
            a_real = amat_elems[:3]
            b_real = amat_elems[3:6]
            c_real = amat_elems[6:]

            # dxtbx indexed crystal model
            C = Crystal(a_real, b_real, c_real, "P43212")

            # change basis here ? Or maybe just average a/b
            a, b, c, _, _, _ = C.get_unit_cell().parameters()
            a_init = .5 * (a + b)
            c_init = c

            # shoe boxes where we expect spots
            bbox_dset = h["bboxes"]["shot%d" % shot_idx]
            n_bboxes_total = bbox_dset.shape[0]
            # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
            is_a_keeper = h["bboxes"]["keepers%d" % shot_idx][()]

            # tilt plane to the background pixels in the shoe boxes
            tilt_abc_dset = h["tilt_abc"]["shot%d" % shot_idx]
            try:
                panel_ids_dset = h["panel_ids"]["shot%d" % shot_idx]
                has_panels = True
            except KeyError:
                has_panels = False

            # apply the filters:
            bboxes = [
                bbox_dset[i_bb] for i_bb in range(n_bboxes_total)
                if is_a_keeper[i_bb]
            ]
            tilt_abc = [
                tilt_abc_dset[i_bb] for i_bb in range(n_bboxes_total)
                if is_a_keeper[i_bb]
            ]
            if has_panels:
                panel_ids = [
                    panel_ids_dset[i_bb] for i_bb in range(n_bboxes_total)
                    if is_a_keeper[i_bb]
                ]
            else:
                panel_ids = [0] * len(tilt_abc)

            # how many pixels do we have
            tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]
            Ntot += sum(tot_pix)

            # actually load the pixels...
            #data_boxes = [ img[j1:j2, i1:i2] for i1,i2,j1,j2 in bboxes]

            # Here we will try a per-shot refinement of the unit cell and Umatrix, as well as ncells abc
            # and spot scale etc..

            # load some ground truth data from the simulation dumps (e.g. spectrum)
            h5_fname = h["h5_path"][shot_idx].replace(".npz", "")
            # NOTE remove me
            if args.testmode:
                h5_fname = os.path.basename(h5_fname)
            data = h5py_File(h5_fname, "r")

            tru = sqr(data["crystalA"][()]).inverse().elems
            a_tru = tru[:3]
            b_tru = tru[3:6]
            c_tru = tru[6:]
            C_tru = Crystal(a_tru, b_tru, c_tru, "P43212")
            try:
                angular_offset_init = compare_with_ground_truth(
                    a_tru, b_tru, c_tru, [C], symbol="P43212")[0]
            except Exception as err:
                print(
                    "Rank %d: Boo cant use the comparison w GT function: %s" %
                    (rank, err))

            fluxes = data["spectrum"][()]
            es = data["exposure_s"][()]
            fluxes *= es  # multiply by the exposure time
            spectrum = zip(wavelens, fluxes)
            # dont simulate when there are no photons!
            spectrum = [(wave, flux) for wave, flux in spectrum
                        if flux > self.flux_min]

            # make a unit cell manager that the refiner will use to track the B-matrix
            aa, _, cc, _, _, _ = C_tru.get_unit_cell().parameters()
            ucell_man = TetragonalManager(a=a_init, c=c_init)
            if args.startwithtruth:
                ucell_man = TetragonalManager(a=aa, c=cc)

            # create the sim_data instance that the refiner will use to run diffBragg
            # create a nanoBragg crystal
            nbcryst = nanoBragg_crystal()
            nbcryst.dxtbx_crystal = C
            if args.startwithtruth:
                nbcryst.dxtbx_crystal = C_tru

            nbcryst.thick_mm = 0.1
            nbcryst.Ncells_abc = 30, 30, 30
            nbcryst.miller_array = Fhkl_guess.as_amplitude_array()
            nbcryst.n_mos_domains = 1
            nbcryst.mos_spread_deg = 0.0

            # create a nanoBragg beam
            nbbeam = nanoBragg_beam()
            nbbeam.size_mm = 0.001
            nbbeam.unit_s0 = B.get_unit_s0()
            nbbeam.spectrum = spectrum

            # sim data instance
            SIM = SimData()
            SIM.detector = CSPAD
            #SIM.detector = D
            SIM.crystal = nbcryst
            SIM.beam = nbbeam
            SIM.panel_id = 0  # default

            spot_scale = 12
            if args.sad:
                spot_scale = 1
            SIM.instantiate_diffBragg(default_F=0, oversample=0)
            SIM.D.spot_scale = spot_scale

            img_in_photons = img / self.gain

            print("Rank %d, Starting refinement!" % rank)
            try:
                RUC = RefineAllMultiPanel(
                    spot_rois=bboxes,
                    abc_init=tilt_abc,
                    img=img_in_photons,  # NOTE this is now a multi panel image
                    SimData_instance=SIM,
                    plot_images=args.plot,
                    plot_residuals=args.residual,
                    ucell_manager=ucell_man)

                RUC.panel_ids = panel_ids
                RUC.multi_panel = True
                RUC.split_evaluation = args.split
                RUC.trad_conv = True
                RUC.refine_detdist = False
                RUC.refine_background_planes = False
                RUC.refine_Umatrix = True
                RUC.refine_Bmatrix = True
                RUC.refine_ncells = True
                RUC.use_curvatures = False  # args.curvatures
                RUC.calc_curvatures = True  #args.curvatures
                RUC.refine_crystal_scale = True
                RUC.refine_gain_fac = False
                RUC.plot_stride = args.stride
                RUC.poisson_only = False
                RUC.trad_conv_eps = 5e-3  # NOTE this is for single panel model
                RUC.max_calls = 300
                RUC.verbose = False
                RUC.use_rot_priors = True
                RUC.use_ucell_priors = True
                if args.verbose:
                    if rank == 0:  # only show refinement stats for rank 0
                        RUC.verbose = True
                RUC.run()
                if RUC.hit_break_to_use_curvatures:
                    RUC.use_curvatures = True
                    RUC.run(setup=False)
            except AssertionError as err:
                print(
                    "Rank %d, filename %s Hit assertion error during refinement: %s"
                    % (rank, data.filename, err))
                continue

            angle, ax = RUC.get_correction_misset(as_axis_angle_deg=True)
            if args.startwithtruth:
                C = Crystal(a_tru, b_tru, c_tru, "P43212")
            C.rotate_around_origin(ax, angle)
            C.set_B(RUC.get_refined_Bmatrix())
            a_ref, _, c_ref, _, _, _ = C.get_unit_cell().parameters()
            # compute misorientation with ground truth model
            try:
                angular_offset = compare_with_ground_truth(a_tru,
                                                           b_tru,
                                                           c_tru, [C],
                                                           symbol="P43212")[0]
                print(
                    "Rank %d, filename=%s, ang=%f, init_ang=%f, a=%f, init_a=%f, c=%f, init_c=%f"
                    % (rank, data.filename, angular_offset,
                       angular_offset_init, a_ref, a_init, c_ref, c_init))
            except Exception as err:
                print("Rank %d, filename=%s, error %s" %
                      (rank, data.filename, err))

            # free the memory from diffBragg instance
            RUC.S.D.free_all()
            del img  # not sure if needed here..
            del img_in_photons

            if args.testmode:
                exit()

            # peek at the memory usage of this rank
            mem = getrusage(RUSAGE_SELF).ru_maxrss  # peak mem usage in KB
            mem = mem / 1e6  # convert to GB

            if rank == 0:
                print "RANK 0: %.2g total pixels in %d/%d bboxes (file %d / %d); MemUsg=%2.2g GB" \
                      % (Ntot, len(bboxes), n_bboxes_total,  img_num+1, len(my_shots), mem)

            # TODO: accumulate all pixels
            #self.all_bbox_pixels += data_boxes

        for h in my_open_files.values():
            h.close()

        print("Rank %d; all subimages loaded!" % rank)
Example #5
def main():

    login_name, login_password, year = None, None, None
    year_current, today = datetime.now().strftime("%Y"), datetime.now().date()

    for i in range(len(argv)):

        if argv[i] == '-y' and argv[i + 1:]:
            if match(r"^[0-9]{4}$", argv[i + 1]):
                year = argv[i + 1]
            else:
                print("无效的考核年度!")
                return
        elif argv[i] == '-n' and argv[i + 1:]:
            login_name = argv[i + 1]
        elif argv[i] == '-p' and argv[i + 1:]:
            login_password = argv[i + 1]

    url_login_info = getRealPath("login_info_cq.npy")

    if login_name and login_password:
        numpy_save(url_login_info, [login_name, login_password])
    else:
        try:
            login_name, login_password = numpy_load(url_login_info)
        except:
            print("首次使用,请设置登录信息:-n 用户名 -p 密码!")
            return

    cq_post = CQPost()
    cq_post.login(login_name, login_password)

    # start reading the Excel files
    url_seed, url_target = getRealPath("seed.xlsx"), getDesktopPath(
        "部门需求看板.xlsx")
    seed = pd.read_excel(url_seed).fillna("", inplace=False)
    groups = [seed.iloc[x, 0] for x in range(len(seed))]
    members = [[y for y in seed.iloc[x, 1].split("|") if y]
               for x in range(len(seed))]
    members_set = set([x for y in members for x in y])
    if sum([len(x) for x in members]) > len(members_set):
        print("成员列存在重复配置的项, 请检查seed文档!")
        return

    groups.append("其它")
    members.append([])

    # writing each column follows the same pattern, so factor it into a helper function
    def write_col(worksheet, index, head, col, format_head, format_cell):
        worksheet.set_column(index, index,
                             max(len_cell(head), *[len_cell(x) for x in col]))
        worksheet.write(0, index, head, format_head)
        worksheet.write_column(1, index, col, format_cell)

    with pd.ExcelWriter(url_target) as writer:
        df, workbook = pd.DataFrame(), writer.book
        format_head = workbook.add_format({
            'bold': True,
            'bg_color': '#00CD66',
            'border': 1,
            'align': 'left'
        })
        format_cell = workbook.add_format({
            "text_wrap": True,
            'align': 'left',
            'valign': "vcenter"
        })

        # **************************************************************************************************************
        # handle the real-time data: pull it first
        cq_tmp = cq_post.pullCQ()
        cq_temp_dict = defaultdict(list)
        if cq_tmp['resultSetData']['rowData']:
            for item in cq_tmp['resultSetData']['rowData']:
                item['提出日期'] = strUtcTimeToLocalDate(item['提出日期'])
                item['计划更新UAT日期'] = strUtcTimeToLocalDate(item['计划更新UAT日期'])
                item['计划投产日期'] = strDateToDate(item['计划投产日期'])
                item['标准功能点数'] = float(item['标准功能点数'])
                cq_temp_dict[item['owner']].append(item)

        members[-1] = [x for x in cq_temp_dict.keys() if x not in members_set]
        df.to_excel(writer, index=False, header=False, sheet_name="实时统计")
        worksheet = writer.sheets['实时统计']

        xls = []
        # column: project group (项目组)
        head, col = "项目组", groups
        xls.append([head, col])

        # column: members (成员)
        head, col = "成员", ['\n'.join(x) for x in members]
        xls.append([head, col])

        # every column from the third onward needs the grouped CQ info, so build it once
        cqs = [
            hstack([cq_temp_dict[y] for y in x]) if x else [] for x in members
        ]

        # column: total open CQ tickets, excluding rejected/archived/closed ones
        head, col = "未关闭总数", [len(x) for x in cqs]
        xls.append([head, col])

        # column: CQ tickets in development
        state = {
            "等待开发", "正在开发", "等待安装SIT", "等待同步安装SIT", "等待SIT测试", "等待安装", "等待同步安装"
        }
        head, col = "开发阶段", [
            len([y for y in x if y['State'] in state]) for x in cqs
        ]
        xls.append([head, col])

        # column: CQ tickets in testing
        state = {"等待检测", "正在检测", "等待同步检测"}
        head, col = "测试阶段", [
            len([y for y in x if y['State'] in state]) for x in cqs
        ]
        xls.append([head, col])

        # column: CQ tickets awaiting production release
        state = {"等待投产"}
        head, col = "等待投产", [
            len([y for y in x if y['State'] in state]) for x in cqs
        ]
        xls.append([head, col])

        # column: CQ tickets awaiting review
        state = {"等待审核"}
        f_v = lambda x: x['问题编号'] + '-' + x['owner.fullname']
        f_if = lambda x: x['State'] in state
        head, col = "等待审核", [
            "\n".join([f_v(y) for y in x if f_if(y)]) for x in cqs
        ]
        xls.append([head, col])

        # column: total function points involved
        head, col = "总功能点", [int(sum([y['标准功能点数'] for y in x])) for x in cqs]
        xls.append([head, col])

        # column: average function points per member
        head, col = "均功能点", [
            int(
                sum([y['标准功能点数']
                     for y in cqs[i]]) / len(members[i]) if members[i] else 0)
            for i in range(len(cqs))
        ]
        xls.append([head, col])

        # column: UAT updates due today, ignoring tickets awaiting review
        state = {
            "等待开发", "正在开发", "等待安装SIT", "等待同步安装SIT", "等待SIT测试", "等待安装", "等待同步安装"
        }
        f_v = lambda x: x['问题编号'] + '-' + x['State'] + "-" + x['owner.fullname']
        f_if = lambda x: x['计划更新UAT日期'] == today and x['State'] in state
        head, col = "待更新UAT", [
            "\n".join([f_v(y) for y in x if f_if(y)]) for x in cqs
        ]
        xls.append([head, col])

        # column: overdue UAT updates, ignoring tickets awaiting review; the algorithm needs rework,
        # since it only counts tickets whose UAT update was due by yesterday and never happened, not every overdue case (e.g. updated, but updated late)
        state = {
            "等待开发", "正在开发", "等待安装SIT", "等待同步安装SIT", "等待SIT测试", "等待安装", "等待同步安装"
        }
        f_v = lambda x: (x['问题编号'] + '-' + disDay(x['计划更新UAT日期'], today) +
                         "天" + '-' + x['owner.fullname'])
        f_if = lambda x: x['计划更新UAT日期'] < today and x['State'] in state
        head, col = "UAT逾期", [
            "\n".join([f_v(y) for y in x if f_if(y)]) for x in cqs
        ]
        xls.append([head, col])

        for i in range(len(xls)):
            write_col(worksheet, i, xls[i][0], xls[i][1], format_head,
                      format_cell)

        # **************************************************************************************************************
        # also store the real-time CQ details in their own sheet for easy lookup
        df.to_excel(writer, index=False, header=False, sheet_name="实时CQ单汇总")
        worksheet = writer.sheets['实时CQ单汇总']

        # **************************************************************************************************************
        # handle the current-year cumulative data: pull it first

        cq_year_current = cq_post.pullCQYear(year_current)
        cq_year_current_dict = defaultdict(list)
        if cq_year_current['resultSetData']['rowData']:
            for item in cq_year_current['resultSetData']['rowData']:
                item['提出日期'] = strUtcTimeToLocalDate(item['提出日期'])
                item['计划更新UAT日期'] = strUtcTimeToLocalDate(item['计划更新UAT日期'])
                item['计划投产日期'] = strDateToDate(item['计划投产日期'])
                item['标准功能点数'] = float(item['标准功能点数'])
                cq_year_current_dict[item['owner']].append(item)

        df.to_excel(writer,
                    index=False,
                    header=False,
                    sheet_name=year_current + "年累计")
        worksheet = writer.sheets[year_current + "年累计"]

        # **************************************************************************************************************
        # handle the cumulative data for the requested year: pull it first
        if year:

            cq_year = cq_post.pullCQYear(year)
            cq_year_dict = defaultdict(list)
            if cq_year['resultSetData']['rowData']:
                for item in cq_year['resultSetData']['rowData']:
                    item['提出日期'] = strUtcTimeToLocalDate(item['提出日期'])
                    item['计划更新UAT日期'] = strUtcTimeToLocalDate(
                        item['计划更新UAT日期'])
                    item['计划投产日期'] = strDateToDate(item['计划投产日期'])
                    item['标准功能点数'] = float(item['标准功能点数'])
                    cq_year_dict[item['owner']].append(item)

            df.to_excel(writer,
                        index=False,
                        header=False,
                        sheet_name=year + "年累计")
            worksheet = writer.sheets[year + '年累计']
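
The column-writing pattern in write_col relies on xlsxwriter through pandas' ExcelWriter; a self-contained sketch with illustrative data, where plain len() stands in for the len_cell helper used above:

# hedged sketch of the write_col pattern, assuming the xlsxwriter engine
import pandas as pd

with pd.ExcelWriter("demo.xlsx", engine="xlsxwriter") as writer:
    pd.DataFrame().to_excel(writer, index=False, header=False, sheet_name="demo")
    workbook, worksheet = writer.book, writer.sheets["demo"]
    fmt_head = workbook.add_format({"bold": True, "border": 1, "align": "left"})
    fmt_cell = workbook.add_format({"text_wrap": True, "align": "left"})
    head, col = "group", ["team A", "team B"]
    worksheet.set_column(0, 0, max(len(head), *[len(x) for x in col]))
    worksheet.write(0, 0, head, fmt_head)
    worksheet.write_column(1, 0, col, fmt_cell)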
Example #6
def main():
    session, base_url = Session(), "http://xxx.xxx.xxx.xxx:xxxx"

    # handle user input
    print("**********************************************************************************************************")
    print("用法:mobsapi -s 开始时间 -e 结束时间 -t 异常阙值 -n 登录名 -p 登录秘钥")
    print("时间格式:HH:MM(当天)或YYYY-MM-DD HH:MM(指定天),时间跨度建议一小时")
    print("经测试,非当日时间可以任意指定时间跨度,当日只能1~2个小时,多了服务器会报错")
    print("不指定时间,默认最近一个小时;异常阙值缺省为1.30;首次使用请指定登陆信息")
    print("**********************************************************************************************************")

    # defaults for the time window and the anomaly threshold
    now, one_hour_ago = datetime.now(), datetime.now() - timedelta(hours=1)
    time_start, time_end, threshold = one_hour_ago.strftime("%Y-%m-%d %H:%M"), now.strftime("%Y-%m-%d %H:%M"), 1.3

    def getpath(filename) -> str:
        if path[0].endswith(".zip"):
            return str.replace(path[0], "base_library.zip", filename)
        else:
            return path[0] + "\\" + filename

    login_name, login_password = None, None
    for i in range(len(argv)):
        if argv[i] == '-s' and argv[i + 1:]:
            if match(r"[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}", argv[i + 1]):
                time_start = argv[i + 1]
            elif match(r"[0-9]{2}:[0-9]{2}", argv[i + 1]):
                time_start = now.strftime("%Y-%m-%d") + " " + argv[i + 1]
            elif match(r"[0-9]{1}:[0-9]{2}", argv[i + 1]):
                time_start = now.strftime("%Y-%m-%d") + " " + "0" + argv[i + 1]
            else:
                print("无效的开始时间!")
                return
        elif argv[i] == '-e' and argv[i + 1:]:
            if match(r"[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}", argv[i + 1]):
                time_end = argv[i + 1]
            elif match(r"[0-9]{2}:[0-9]{2}", argv[i + 1]):
                time_end = now.strftime("%Y-%m-%d") + " " + argv[i + 1]
            elif match(r"[0-9]{1}:[0-9]{2}", argv[i + 1]):
                time_end = now.strftime("%Y-%m-%d") + " " + "0" + argv[i + 1]
            else:
                print("无效的截止时间!")
                return
        elif argv[i] == '-t' and argv[i + 1:]:
            if match(r"[0-9]+.?[0-9]*", argv[i + 1]):
                threshold = float(argv[i + 1])
            else:
                print("无效的异常阙值!")
                return
        elif argv[i] == '-n' and argv[i + 1:]:
            login_name = argv[i + 1]
        elif argv[i] == '-p' and argv[i + 1:]:
            # the intelligent-ops platform expects a base64-encoded password
            # note: b64encode takes and returns bytes, i.e. b'xxx'
            # convert between bytes and str with encode/decode
            login_password = b64encode(argv[i + 1].encode('utf-8')).decode('utf-8')

    url_login_info = getpath("login_info.npy")

    if login_name and login_password:
        numpy_save(url_login_info, [login_name, login_password])
    else:
        try:
            login_name, login_password = numpy_load(url_login_info)
        except:
            print("首次使用,请设置登录信息:-n 用户名 -p 密码!")
            return

    print("时间区间{0}~{1},异常阙值{2:.2f}......".format(time_start, time_end, threshold))

    # log in to obtain an auth token
    def login(login_name, login_password, base_url) -> str:
        login_url = base_url + "/imms/userLogin.do"
        params = {
            "REQ_HEAD": {"TRAN_PROCESS": "", "TRAN_ID": ""},
            "REQ_BODY": {
                "flag": "login",
                "logSegment": "OA",
                'loginName': login_name,
                "loginPassword": login_password,
                'loginState': "0",
                "logIP": "",
                "userAuthSession": "",
            }
        }

        try:
            # note: the request body is serialized with json_dumps here; sometimes that is not needed
            # if Fiddler shows a JSON body, serialize it; if it shows a=xx&b=xx&c=xx, send params directly
            return json_loads(session.post(login_url, json_dumps(params)).text)['RSP_BODY']['userAuthToken']
        except:
            print("登陆失败!请检查网络设置或重新配置登录信息:-n 用户名 -p 密码!")
            exit(1)

    print("登录中......")
    userAuthSession = login(login_name, login_password, base_url)

    # query; in testing, a two-hour window is about the limit before the server returns an error
    def pull(login_name, userAuthSession, base_url, time_start, time_end) -> list:
        pull_url = base_url + "/imms/qryDimStatisInfo.do"
        # note: the params are type-sensitive; if the 4 below is sent as a string, no data comes back
        params = {
            "REQ_HEAD": {"TRAN_PROCESS": "", "TRAN_ID": ""},
            "REQ_BODY": {
                "branchCode": "xxx",
                "loginName": login_name,
                "userAuthSession": userAuthSession,
                "pageSize": "8",
                "currentPage": "1",
                "startTime": time_start,
                "endTime": time_end,
                "firstDimensionsNo": "transCode",
                "secondDimensionsNo": "",
                "systemId": "MOBS",
                "timeType": 4,
                "type": "MX",

            }
        }
        try:
            return json_loads(session.post(pull_url, json_dumps(params)).text)['RSP_BODY']['dimStatisInfoVoList']
        except:
            print("拉取失败!可能时间区间设置过大,请再次尝试!")
            exit(1)

    print("数据拉取中......")
    dimStatisInfoVoList = pull(login_name, userAuthSession, base_url, time_start, time_end)

    if dimStatisInfoVoList:
        print("数据拉取成功......")
    else:
        print("拉取数据有误!可能服务器不稳定,请再次尝试!")
        return

    print("**********************************************************************************************************")

    def process(dimStatisInfoVoList):
        """
        Besides pulling from the web, the data can also be read from Excel with pandas' read_excel:
        xls = read_excel(filename, skiprows=4, header=None, skipfooter=1, usecols=[0, 1, 2, 3, 5])
        xls = xls.values.tolist()
        """
        """
        The pulled data has many more columns than the analysis needs; the unused ones may serve some other purpose later:
        xls = [[x['tradeDate'] + ' ' + x['tradeTime'], x['secondDimensionsNo'], x['configDesc'], int(x['tradeCount']),
                int(x['successTradeCount']), float(x['avgTime']), float(x['successRate']),
                float(x['systemSuccessRate']),
                float(x['tradeSuccessRate'])] for x in dimStatisInfoVoList]
        """
        xls = [[x['tradeDate'] + ' ' + x['tradeTime'], x['secondDimensionsNo'], x['configDesc'],
                int(x['tradeCount']), float(x['avgTime'])] for x in dimStatisInfoVoList]
        # build two dicts, keyed by timestamp and by API respectively, for later use
        dict_time, dict_api = defaultdict(list), defaultdict(list)
        for item in xls:
            dict_time[item[0]].append(item[1:])
            dict_api[item[1]].append([item[0]] + item[2:])

        # compute a few aggregate statistics: time span, total call count across all APIs, total elapsed time, overall average latency
        time_range = len(set([item[0] for item in xls]))
        count_all = sum([item[3] for item in xls])
        duras_all = sum([item[3] * item[4] for item in xls])
        duras_avg_all = duras_all / count_all if count_all > 0 else 0

        # per-timestamp statistics: call count across all APIs for that minute, total elapsed time, average latency
        dict_stati_time = defaultdict(list)
        for key, value in dict_time.items():
            count_time = sum([item[2] for item in value])
            duras_time = sum([item[2] * item[3] for item in value])
            duras_avg_time = duras_time / count_time if count_time > 0 else duras_avg_all
            dict_stati_time[key].extend([count_time, duras_time, duras_avg_time])

        # per-API statistics: API name, total call count, total elapsed time, average latency, standard deviation
        dict_stati_api = defaultdict(list)
        for key, value in dict_api.items():

            count_api = sum([item[2] for item in value])
            duras_api = sum([item[2] * item[3] for item in value])
            duras_avg_api = duras_api / count_api if count_api > 0 else duras_avg_all

            # min-max normalise the average latencies before computing the unbiased standard deviation
            duras_l = [item[3] for item in value]
            # when normalising, handle min == max (typically very low-frequency APIs) by setting the std to 0
            if max(duras_l) == min(duras_l):
                std_api = 0
            else:
                std_api = std([(x - min(duras_l)) / (max(duras_l) - min(duras_l)) for x in duras_l], ddof=1)

            dict_stati_api[key].extend([value[0][1], count_api, duras_api, duras_avg_api, std_api])

        # print the codes and names of the top-ten APIs by each metric; filter out low-frequency ones (at least time_range*60 calls)
        l_s_api = [item for item in list(dict_stati_api.items()) if item[1][1] >= time_range * 60]
        print("1. 每分钟交易量排名前十的高频接口(平均每秒至少调用一次)分别为:")
        l_s_api.sort(key=lambda x: x[1][1], reverse=True)
        print("{0:<8}{1:<8}{2:<8}{3:<8}{4:<}".format("接口编号", "分均频次", "平均耗时", "波动程度", "接口名称"))
        for i in range(10):
            list_print = [l_s_api[i][0], l_s_api[i][1][1] / time_range, l_s_api[i][1][3], l_s_api[i][1][4],
                          l_s_api[i][1][0]]
            print("{0:<12}{1:<12.0f}{2:<12.2f}{3:<12.2f}{4:<}".format(*list_print))
        print(
            "**********************************************************************************************************")

        print("2. 平均调用耗时排名前十的高频接口(平均每秒至少调用一次)分别为:")
        l_s_api.sort(key=lambda x: x[1][3], reverse=True)
        print("{0:<8}{1:<8}{2:<8}{3:<8}{4:<}".format("接口编号", "分均频次", "平均耗时", "波动程度", "接口名称"))
        for i in range(10):
            list_print = [l_s_api[i][0], l_s_api[i][1][1] / time_range, l_s_api[i][1][3], l_s_api[i][1][4],
                          l_s_api[i][1][0]]
            print("{0:<12}{1:<12.0f}{2:<12.2f}{3:<12.2f}{4:<}".format(*list_print))
        print(
            "**********************************************************************************************************")

        print("3. 平均耗时波动排名前十的高频接口(平均每秒至少调用一次)分别为:")
        l_s_api.sort(key=lambda x: x[1][4], reverse=True)
        print("{0:<8}{1:<8}{2:<8}{3:<8}{4:<}".format("接口编号", "分均频次", "平均耗时", "波动程度", "接口名称"))
        for i in range(10):
            list_print = [l_s_api[i][0], l_s_api[i][1][1] / time_range, l_s_api[i][1][3], l_s_api[i][1][4],
                          l_s_api[i][1][0]]
            print("{0:<12}{1:<12.0f}{2:<12.2f}{3:<12.2f}{4:<}".format(*list_print))
        print(
            "**********************************************************************************************************")

        # plotting in a child thread or process does not seem to be supported by matplotlib,
        # so instead a child thread prompts for the timestamp to analyse while the main process draws the plot
        class AnalysePeak(Thread):
            def __init__(self, dict_time, dict_stati_api, duras_avg_all):
                Thread.__init__(self)
                self.dic_time = dict_time
                self.dict_stati_api = dict_stati_api
                self.duras_avg_all = duras_avg_all

            def run(self):
                print(
                    "**********************************************************************************************************")

                while True:

                    time = input("请输入待分析时间点YYYY-MM-DD HH:MM:").strip()
                    if time not in self.dic_time:
                        print("无效的时间输入!")
                    else:
                        list_apis = dict_time[time].copy()
                        for item in list_apis:
                            # anomaly contribution = max(average latency at this timestamp - overall average latency, 0) * call count at this timestamp
                            # a given API may have no overall average latency; fall back to the global average duras_avg_all
                            if item[0] in self.dict_stati_api:
                                duras_avg_api = self.dict_stati_api[item[0]][3]
                            else:
                                duras_avg_api = self.duras_avg_all
                            contribute_value = max(0, item[2] * (item[3] - duras_avg_api))
                            item.extend([duras_avg_api, contribute_value])

                        # after computing each API's contribution, sum them and work out each API's share
                        # handle the rare case where the overall average is high and every API at this timestamp was fast, so every contribute_value is 0
                        contribute_value_all = sum([item[5] for item in list_apis])
                        for item in list_apis:
                            contribute_percent = item[5] / contribute_value_all if contribute_value_all != 0 else 0
                            item.append(contribute_percent)

                        list_apis.sort(key=lambda x: x[6], reverse=True)
                        print(
                            "**********************************************************************************************************")
                        print("时点{0}的接口耗时异常增加主要由以下接口导致(打印前二十个):".format(time))
                        head = ["接口编号", "时点频次", "时点均耗", "正常均耗", "贡献量", "贡献度", "接口名称"]
                        print("{0:<8}{1:<8}{2:<8}{3:<8}{4:<9}{5:<9}{6:<}".format(*head))
                        for i in range(20):
                            item = list_apis[i]
                            list_print = [item[0], item[2], item[3], item[4], item[5], item[6], item[1]]
                            print("{0:<12}{1:<12.0f}{2:<12.2f}{3:<12.2f}{4:<12.0f}{5:<12.2%}{6:<}".format(*list_print))
                        print(
                            "**********************************************************************************************************")

        # one caveat: if dict_stati_api and duras_avg_all come from the analysed window itself, the anomaly may span a long time, and
        # each API's average and the overall average are already inflated; e.g. an API normally averages 30ms, the analysed window averages 40ms and the analysed timestamp 50ms, so is the contribution 20 or 10?
        # so keep historical baselines: whenever the overall average latency is within a reasonable range, save the new dict_stati_api and duras_avg_all for next time

        # the with statement handles errors during reading, but not errors from open() itself, so an outer try is still needed
        # pickle is used instead of numpy here, because numpy keeps converting formats while pickle returns exactly what was stored
        # sys.path is used to find the executable's directory; other approaches cannot cover scripts, exe bundles and PATH invocation at once

        url_dict_stati_api, url_duras_avg_all = getpath("dict_stati_api.pickle"), getpath("duras_avg_all.pickle")

        try:
            with open(url_dict_stati_api, "rb") as file:
                dict_stati_api_old = pickle_load(file)
            with open(url_duras_avg_all, "rb") as file:
                duras_avg_all_old = pickle_load(file)
        except:
            dict_stati_api_old = dict_stati_api
            duras_avg_all_old = duras_avg_all

        # e.g. if the stored baseline is 60ms, anything <= 66ms still counts as normal and gets saved
        # also avoid saving very short windows; only save when the time span covers at least 55 minutes
        if time_range >= 55 and duras_avg_all <= duras_avg_all_old * 1.1:
            with open(url_dict_stati_api, "wb") as file:
                pickle_dump(dict_stati_api, file)
            with open(url_duras_avg_all, "wb") as file:
                pickle_dump(duras_avg_all, file)

        threadAP = AnalysePeak(dict_time, dict_stati_api_old, duras_avg_all_old)
        threadAP.setDaemon(True)

        # do not start the child thread yet; more output follows, starting it right before plotting is enough

        # the main process goes on to plot; the three parameters are x, y and the annotation threshold
        def draw(x, y, threshold, threadAP):

            plt.figure(figsize=(12, 12 * 0.618))
            plt.xlabel("Time")
            plt.ylabel("Average Time-Consuming(ms)")

            plt.plot(y)
            # axis() sets the x- and y-axis ranges
            plt.axis([0, len(x), 0, max(y) * 1.1])

            # keep the number of x-axis ticks at 24 or fewer
            step = len(x)//24+1
            range_x, mark_x = range(0, len(x), step), [x[i][11:] for i in range(0, len(x), step)]

            plt.xticks(range_x, mark_x)

            # annotate points on the line plot only when they exceed the threshold and are local peaks
            for i in range(len(x)):
                if y[i] >= threshold and (i == 0 or i == len(x)-1 or (y[i] >= y[i-1] and y[i] >= y[i+1])):
                    plt.annotate(x[i][11:], xy=(i, y[i]), xytext=(-20, 10), textcoords="offset pixels", color="red")

            # start the child thread before the main process blocks
            threadAP.start()
            url_img = getpath("figure.png")
            plt.savefig(url_img)
            plt.show()

        l_s_time = list(dict_stati_time.items())

        # sort by latency in descending order and print the five timestamps with the highest average latency
        l_s_time.sort(key=lambda x: x[1][2], reverse=True)
        print("4. 平均耗时排名前五的时点为(时段总平均耗时{:.2f}ms),参考平均耗时{:.2f}ms:".format(duras_avg_all, duras_avg_all_old))
        print("{0:<22}{1:<8}".format("时点", "平均耗时"))
        for i in range(5):
            print("{0:<24}{1:<12.2f}".format(l_s_time[i][0], l_s_time[i][1][2]))
        print(
            "**********************************************************************************************************")

        # plot to show more intuitively how the average API latency changes over time
        l_s_time.sort(key=lambda x: x[0])
        x, y = [x[0] for x in l_s_time], [x[1][2] for x in l_s_time]
        # set the threshold, default 1.3: if the average is 100ms, anything from 130ms up counts as high
        draw(x, y, duras_avg_all * threshold, threadAP)

    process(dimStatisInfoVoList)
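
The per-API "volatility" figure printed in the tables above is the unbiased standard deviation of the min-max normalised per-minute average latencies; a standalone sketch with illustrative values:

# hedged sketch of the volatility metric used above (illustrative latencies, in ms)
from numpy import std

duras_l = [31.0, 28.5, 55.0, 30.2]
lo, hi = min(duras_l), max(duras_l)
std_api = 0 if hi == lo else std([(x - lo) / (hi - lo) for x in duras_l], ddof=1)
print("volatility: %.2f" % std_api)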
Example #7
    def load(self):
        # some parameters
        # NOTE: for reference, inside each h5 file there is
        #   [u'Amatrices', u'Hi', u'bboxes', u'h5_path']
        # get the total number of shots using worker 0
        if rank == 0:
            self.time_load_start = time.time()
            print("I am root. I am calculating total number of shots")
            h5s = [h5py_File(f, "r") for f in self.fnames]
            Nshots_per_file = [h["h5_path"].shape[0] for h in h5s]
            Nshots_tot = sum(Nshots_per_file)
            print("I am root. Total number of shots is %d" % Nshots_tot)

            print("I am root. I will divide shots amongst workers.")
            shot_tuples = []
            roi_per = []
            for i_f, fname in enumerate(self.fnames):
                fidx_shotidx = [(i_f, i_shot) for i_shot in range(Nshots_per_file[i_f])]
                shot_tuples += fidx_shotidx

                # store the number of usable roi per shot in order to divide shots amongst ranks equally
                roi_per += [sum(h5s[i_f]["bboxes"]["%s%d" % (args.keeperstag, i_shot)][()])
                            for i_shot in range(Nshots_per_file[i_f])]

            from numpy import array_split
            from numpy.random import permutation
            print ("I am root. Number of uniques = %d" % len(set(shot_tuples)))

            # divide the array into chunks of roughly equal sum (total number of ROI)
            if args.partition and args.restartfile is None and args.xinitfile is None:
                diff = np.inf
                roi_per = np.array(roi_per)
                tstart = time.time()
                best_order = range(len(roi_per))
                print("Partitioning for better load balancing across ranks.. ")
                while 1:
                    order = permutation(len(roi_per))
                    res = [sum(a) for a in np.array_split(roi_per[order], size)]
                    new_diff = max(res) - min(res)
                    t_elapsed = time.time() - tstart
                    t_remain = args.partitiontime - t_elapsed
                    if new_diff < diff:
                        diff = new_diff
                        best_order = order.copy()
                        print("Best diff=%d, Parition time remaining: %.3f seconds" % (diff, t_remain))
                    if t_elapsed > args.partitiontime:
                        break
                shot_tuples = [shot_tuples[i] for i in best_order]

            elif args.partition and args.restartfile is not None:
                print ("Warning: skipping partitioning time to use shot mapping as laid out in restart file dir")
            else:
                print ("Proceeding without partitioning")

            # optional to divide into a sub group
            shot_tuples = array_split(shot_tuples, args.ngroups)[args.groupId]
            shots_for_rank = array_split(shot_tuples, size)
            import os  # FIXME, I thought I was imported already!
            if args.outdir is not None:  # save for a fast restart (shot order is important!)
                np.save(os.path.join(args.outdir, "shots_for_rank"), shots_for_rank)
            if args.restartfile is not None:
                # the directory containing the restart file should have a shots for rank file
                dirname = os.path.dirname(args.restartfile)
                print ("Loading shot mapping from dir %s" % dirname)
                shots_for_rank = np.load(os.path.join(dirname, "shots_for_rank.npy"))
                # propagate the shots for rank file...
                if args.outdir is not None:
                    np.save(os.path.join(args.outdir, "shots_for_rank"), shots_for_rank)
            if args.xinitfile is not None:
                # the directory containing the restart file should have a shots for rank file
                dirname = os.path.dirname(args.xinitfile)
                print ("Loading shot mapping from dir %s" % dirname)
                shots_for_rank = np.load(os.path.join(dirname, "shots_for_rank.npy"))
                # propagate the shots for rank file...
                if args.outdir is not None:
                    np.save(os.path.join(args.outdir, "shots_for_rank"), shots_for_rank)

            # close the open h5s..
            for h in h5s:
                h.close()

        else:
            Nshots_tot = None
            shots_for_rank = None
            h5s = None

        # Nshots_tot = comm.bcast( Nshots_tot, root=0)
        if has_mpi:
            shots_for_rank = comm.bcast(shots_for_rank, root=0)
        # h5s = comm.bcast( h5s, root=0)  # pull in the open hdf5 files

        my_shots = shots_for_rank[rank]
        if self.Nload is not None:
            start = 0
            if args.loadstart is not None:
                start = args.loadstart
            my_shots = my_shots[start: start + self.Nload]
        print("Rank %d: I will load %d shots, first shot: %s, last shot: %s"
              % (comm.rank, len(my_shots), my_shots[0], my_shots[-1]))

        # open the unique filenames for this rank
        # TODO: check max allowed pointers to open hdf5 file
        import h5py
        my_unique_fids = set([fidx for fidx, _ in my_shots])
        self.my_open_files = {fidx: h5py_File(self.fnames[fidx], "r") for fidx in my_unique_fids}
        # for fidx in my_unique_fids:
        #    fpath = self.fnames[fidx]
        #    if args.imgdirname is not None:
        #        fpath = fpath.split("/kaladin/")[1]
        #        fpath = os.path.join(args.imgdirname, fpath)
        #    self.my_open_files[fidx] = h5py.File(fpath, "r")
        Ntot = 0

        for img_num, (fname_idx, shot_idx) in enumerate(my_shots):
            h = self.my_open_files[fname_idx]

            # load the dxtbx image data directly:
            npz_path = h["h5_path"][shot_idx]

            if args.imgdirname is not None:
                import os
                npz_path = npz_path.split("/kaladin/")[1]
                npz_path = os.path.join(args.imgdirname, npz_path)

            if args.noiseless:
                noiseless_path = npz_path.replace(".npz", ".noiseless.npz")
                img_handle = numpy_load(noiseless_path)
            else:
                img_handle = numpy_load(npz_path)

            img = img_handle["img"]

            if len(img.shape) == 2:  # if single panel
                img = array([img])

            B = beam_from_dict(img_handle["beam"][()])

            log_init_crystal_scale = 0  # default
            if args.usepreoptscale:
                log_init_crystal_scale = h["crystal_scale_%s" % args.preopttag][shot_idx]
            # get the indexed crystal Amatrix
            Amat = h["Amatrices"][shot_idx]
            if args.usepreoptAmat:
                Amat = h["Amatrices_%s" % args.preopttag][shot_idx]
            amat_elems = list(sqr(Amat).inverse().elems)
            # real space basis vectors:
            a_real = amat_elems[:3]
            b_real = amat_elems[3:6]
            c_real = amat_elems[6:]

            # dxtbx indexed crystal model
            C = Crystal(a_real, b_real, c_real, "P43212")

            # change basis here ? Or maybe just average a/b
            a, b, c, _, _, _ = C.get_unit_cell().parameters()
            a_init = .5 * (a + b)
            c_init = c

            # shoe boxes where we expect spots
            bbox_dset = h["bboxes"]["shot%d" % shot_idx]
            n_bboxes_total = bbox_dset.shape[0]
            # is the shoe box within the resolution ring and does it have significant SNR (see filter_bboxes.py)
            is_a_keeper = h["bboxes"]["%s%d" % (args.keeperstag, shot_idx)][()]
            # tilt plane to the background pixels in the shoe boxes
            tilt_abc_dset = h["tilt_abc"]["shot%d" % shot_idx]
            # miller indices (not yet reduced by symm equivs)
            Hi_dset = h["Hi"]["shot%d" % shot_idx]
            try:
                panel_ids_dset = h["panel_ids"]["shot%d" % shot_idx]
                has_panels = True
            except KeyError:
                has_panels = False

            # apply the filters:
            bboxes = [bbox_dset[i_bb] for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]
            tilt_abc = [tilt_abc_dset[i_bb] for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]
            Hi = [tuple(Hi_dset[i_bb]) for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]
            proc_file_idx = [i_bb for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]

            if has_panels:
                panel_ids = [panel_ids_dset[i_bb] for i_bb in range(n_bboxes_total) if is_a_keeper[i_bb]]
            else:
                panel_ids = [0] * len(tilt_abc)

            # how many pixels do we have
            tot_pix = [(j2 - j1) * (i2 - i1) for i1, i2, j1, j2 in bboxes]
            Ntot += sum(tot_pix)
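            # Ntot is the running count of pixels over all kept shoeboxes on this rank,
            # a rough measure of the memory/work this rank will carry.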

            # load some ground truth data from the simulation dumps (e.g. spectrum)
            # h5_fname = h["h5_path"][shot_idx].replace(".npz", "")
            h5_fname = npz_path.replace(".npz", "")
            if args.character is not None:
                h5_fname = h5_fname.replace("rock", args.character)
            if args.testmode2:
                h5_fname = npz_path.split(".npz")[0]
            data = h5py_File(h5_fname, "r")

            xtal_scale_truth = data["spot_scale"][()]
            tru = sqr(data["crystalA"][()]).inverse().elems
            a_tru = tru[:3]
            b_tru = tru[3:6]
            c_tru = tru[6:]
            C_tru = Crystal(a_tru, b_tru, c_tru, "P43212")

            fluxes = data["spectrum"][()]
            es = data["exposure_s"][()]

            # comm.Barrier()
            # exit()
            fluxes *= es  # multiply by the exposure time
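            # (the stored spectrum is presumably in photons per second, so scaling by the
            # exposure time gives photons per shot)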
            # TODO: wavelens should come from the imageset file itself
            if "wavelengths" in data.keys():
                wavelens = data["wavelengths"][()]
            else:
                raise KeyError("Wavelengths missing from hdf5 data")
                #from cxid9114.parameters import WAVELEN_HIGH
                #wavelens = [WAVELEN_HIGH]

            spectrum = zip(wavelens, fluxes)
            # don't simulate when there are no photons!
            spectrum = [(wave, flux) for wave, flux in spectrum if flux > self.flux_min]

            if args.forcemono:
                spectrum = [(B.get_wavelength(), sum(fluxes))]
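            # NOTE: with --forcemono the polychromatic spectrum above is collapsed to a single
            # line at the nominal beam wavelength carrying the summed flux (a monochromatic approximation).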

            # make a unit cell manager that the refiner will use to track the B-matrix
            aa, _, cc, _, _, _ = C_tru.get_unit_cell().parameters()
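            # NOTE: aa and cc (the ground-truth cell edges) are unpacked but not used below;
            # the manager is seeded from the indexed cell (a_init, c_init).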
            ucell_man = TetragonalManager(a=a_init, c=c_init)

            # create the sim_data instance that the refiner will use to run diffBragg
            # create a nanoBragg crystal
            self.Fhkl_obs = open_flex(args.Fobs).as_amplitude_array()
            self.Fhkl_ref = None
            if args.Fref is not None:
                self.Fhkl_ref = open_flex(
                    args.Fref).as_amplitude_array()  # this reference miller array is used to track CC and R-factor
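            # NOTE: these reference miller arrays are re-read from disk on every pass through the
            # shot loop; they could presumably be loaded once before the loop instead.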
            if img_num == 0:  # only initialize the simulator after loading the first image
                self.initialize_simulator(C, B, spectrum, self.Fhkl_obs)

            # map the miller array to ASU
            Hi_asu = map_hkl_list(Hi, self.anomalous_flag, self.symbol)
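            # (mapping to the asymmetric unit presumably lets symmetry-equivalent observations
            # share a single structure-factor parameter during refinement)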

            # copy the image as photons (NOTE: don't forget to ditch its references!)
            img_in_photons = (img / args.gainval).astype('float32')

            # Here, take out from the image only what's necessary to perform refinement:
            # first filter the spot rois so they don't fall exactly at the boundary of the image (nanoBragg ROI ranges are inclusive)
            assert len(img_in_photons.shape) == 3  # sanity
            nslow, nfast = img_in_photons[0].shape
            bboxes = array(bboxes)
            for i_bbox, (_, x2, _, y2) in enumerate(bboxes):
                if x2 == nfast:
                    bboxes[i_bbox][1] = x2 - 1  # update roi_xmax
                if y2 == nslow:
                    bboxes[i_bbox][3] = y2 - 1  # update roi_ymax
            # now cache the roi in nanoBragg format ((x1, x2), (y1, y2))
            # and also cache the pixels and the coordinates

            nanoBragg_rois = []  # special nanoBragg format
            xrel, yrel, roi_img = [], [], []
            for i_roi, (x1, x2, y1, y2) in enumerate(bboxes):
                nanoBragg_rois.append(((x1, x2), (y1, y2)))
                yr, xr = np_indices((y2 - y1 + 1, x2 - x1 + 1))
                xrel.append(xr)
                yrel.append(yr)
                pid = panel_ids[i_roi]
                roi_img.append(img_in_photons[pid, y1:y2 + 1, x1:x2 + 1])
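            # At this point each ROI carries everything the refiner needs: the nanoBragg-style
            # ((x1, x2), (y1, y2)) bounds, per-pixel offsets relative to the ROI origin (xrel/yrel),
            # and the measured photon counts (roi_img).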

            # free the full-panel image arrays now that the ROIs are cached
            img = None
            img_in_photons = None
            del img  # redundant after the None assignment, but harmless
            del img_in_photons

            # peek at the memory usage of this rank
            # mem = getrusage(RUSAGE_SELF).ru_maxrss  # peak mem usage in KB
            # mem = mem / 1e6  # convert to GB
            mem = self._usage()

            # print "RANK %d: %.2g total pixels in %d/%d bboxes (file %d / %d); MemUsg=%2.2g GB" \
            #      % (rank, Ntot, len(bboxes), n_bboxes_total,  img_num +1, len(my_shots), mem)
            self.all_pix += Ntot

            # accumulate per-shot information
            self.global_image_id[img_num] = None  # TODO
            self.all_spot_roi[img_num] = bboxes
            self.all_abc_inits[img_num] = tilt_abc
            self.all_panel_ids[img_num] = panel_ids
            self.all_ucell_mans[img_num] = ucell_man
            self.all_spectra[img_num] = spectrum
            self.all_crystal_models[img_num] = C
            # log of the initial crystal scale factor for this shot
            self.log_of_init_crystal_scales[img_num] = log_init_crystal_scale
            self.all_crystal_scales[img_num] = xtal_scale_truth
            self.all_crystal_GT[img_num] = C_tru
            self.all_xrel[img_num] = xrel
            self.all_yrel[img_num] = yrel
            self.all_nanoBragg_rois[img_num] = nanoBragg_rois
            self.all_roi_imgs[img_num] = roi_img
            self.all_fnames[img_num] = npz_path
            self.all_proc_fnames[img_num] = h.filename
            self.all_Hi[img_num] = Hi
            self.all_Hi_asu[img_num] = Hi_asu
            self.all_proc_idx[img_num] = proc_file_idx
            self.all_shot_idx[img_num] = shot_idx  # this is the index of the shot in the process*h5 file
            # NOTE all originZ for each panel are the same in the simulated data.. Not necessarily true for real data
            shot_originZ = self.SIM.detector[0].get_origin()[2]
            self.shot_originZ_init[img_num] = shot_originZ
            # progress indicator for this rank
            print("Rank %d: finished loading shot %d / %d" % (comm.rank, img_num + 1, len(my_shots)))

        for h in self.my_open_files.values():
            h.close()
        stderr.write('there is an error in transformation file "{0:s}" (it should be valid json)\n'.format(args.trans))
        stderr.write("{0:s}\n".format(str(err)))
        exit(2)
else:
    stdout.write('transformation file "{0:s}" doesn\'t exist; using empty transformations\n'.format(args.trans))
    config = {"input_vars": None}
    transformations = OrderedDict()

if args.trans_out is None:
    args.trans_out = args.trans

if not args.data_in:
    stderr.write("please provide an input data file (--help for info)\n")
    exit(3)
try:
    raw = numpy_load(args.data_in)
except FileNotFoundError:
    stderr.write('input file "{0:s}" could not be found\n'.format(args.data_in))
    exit(4)
except OSError:
    stderr.write('there is an error in input file "{0:s}" (it should be a valid npy file)\n'.format(args.data_in))
    exit(5)

if args.transpose:
    raw = raw.T

if args.limit:
    stdout.write("--limit loads the whole array before slicing it (no easy way to partially load npy files)\n")
    raw = raw[:, : args.limit]
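    # NOTE: this keeps only the first args.limit columns (presumably samples, assuming
    # variables lie along axis 0 after the optional transpose above)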

if args.var_names: