Example #1
0
File: pick.py  Project: zhaotf16/cnpick
def pick(opt):
    """Run the particle-picking detector over one image or a directory of images.

    Args:
        opt: options object. Attributes read here: ``gpus_str``, ``debug``
            (raised to at least 1 in place), ``task`` (key into
            ``detector_factory``), ``data`` (a file path or a directory) and
            ``data_type`` (``'mrc'`` or ``'png'``; any other value does nothing
            after the detector is built).

    Side effects:
        Sets ``CUDA_VISIBLE_DEVICES``, prints per-image timing lines, and for
        ``'mrc'`` input writes ``merge_results.thi`` to the current working
        directory, listing each micrograph file name next to its ``.thi`` file
        name.
    """

    def _print_time_stats(ret):
        # One line per image: "<stat> <seconds>s |" for every entry of time_stats.
        print(''.join('{} {:.3f}s |'.format(stat, ret[stat]) for stat in time_stats))

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    opt.debug = max(opt.debug, 1)
    Detector = detector_factory[opt.task]
    detector = Detector(opt)

    if os.path.isdir(opt.data):
        # Keep directory entries whose lower-cased extension is in image_ext,
        # in sorted order.
        image_names = [
            os.path.join(opt.data, file_name)
            for file_name in sorted(os.listdir(opt.data))
            if file_name[file_name.rfind('.') + 1:].lower() in image_ext
        ]
    else:
        image_names = [opt.data]

    if opt.data_type == 'mrc':
        mrc_thi = []
        for image_name in image_names:
            with open(image_name, "rb") as f:
                content = f.read()
            data, header, _ = parse(content=content)
            print('downsampling', image_name, '...')
            # NOTE(review): the second size is scaled by shape[0]/shape[1],
            # while load_and_downsample scales by shape[1]/shape[0]; confirm
            # which axis order parse() returns.
            second_size = int(data.shape[0] / data.shape[1] * 1024)
            print(second_size)
            data = downsample_with_size(data, 1024, second_size)
            data = quantize(data)
            data = cv2.equalizeHist(data)
            # Stack the single channel three times to get a 3-channel image.
            data = cv2.merge([data, data, data])
            file_name = image_name.split('/')[-1]
            name = file_name.replace('.mrc', '')
            thi_name = file_name.replace('.mrc', '.thi')
            # Record the micrograph file name (not image_name[-1], which is
            # only the last character of the path).
            mrc_thi.append((file_name, thi_name))

            ret = detector.run(data, header, name)
            _print_time_stats(ret)
        with open('merge_results.thi', "w") as f:
            f.write(
                '[Micrograph Particle coordinate:\n #0:MICROGRAPH_PATH    STRING\n #1:PARTICLE_PATH    STRING]\n'
            )
            for micrograph_name, thi_file_name in mrc_thi:
                f.write('@%s\t@%s\n' % (micrograph_name, thi_file_name))
    elif opt.data_type == 'png':
        for image_name in image_names:
            ret = detector.run(image_name, None, image_name)
            _print_time_stats(ret)
Example #2
0
def load_and_downsample(path, target_size):
    """Read an MRC file, average its frames if there are several, and downsample it.

    Args:
        path: path to the MRC file.
        target_size: first target size passed to ``mrc.downsample_with_size``;
            the second size is derived from it using the image's aspect ratio
            (``shape[1] / shape[0]``).

    Returns:
        ``mrc.MrcData(name, data, header)``, where ``name`` is the file name
        up to its first ``.``.
    """
    with open(path, "rb") as f:
        content = f.read()
    data, header, _ = mrc.parse(content=content)
    name = path.split('/')[-1].split('.')[0]
    # header[2] is used as the number of frames stacked along axis 0
    # (presumably the MRC nz field -- confirm against mrc.parse).
    n_frames = header[2]
    if n_frames > 1:
        # Average the first n_frames frames. np.mean replaces a manual
        # accumulator that passed a frame as zeros_like's dtype argument and
        # would also fail on in-place division of integer data.
        data = np.mean(data[:n_frames], axis=0)
    # Scale the second size by target_size (not a fixed 1024) so the aspect
    # ratio is preserved for any target size.
    data = mrc.downsample_with_size(
        data, target_size, int(data.shape[1] / data.shape[0] * target_size))
    return mrc.MrcData(name, data, header)
Example #3
0
def downsample(inputs, use_factor=False, para1=None, para2=None):
    """Downsample every item of ``inputs`` in place.

    Args:
        inputs: list of MrcData-like objects with ``name`` and ``data``
            attributes; each ``data`` is replaced by its downsampled version.
        use_factor: if True, ``para1`` is the downsampling factor and
            ``para2`` the shape, forwarded to ``mrc.downsample_with_factor``.
            Otherwise ``para1`` and ``para2`` are the target size (y, x),
            forwarded to ``mrc.downsample_with_size``.
        para1: factor or first target size (see ``use_factor``).
        para2: shape or second target size (see ``use_factor``).

    Returns:
        The same ``inputs`` object; its items are modified in place.
    """
    for item in inputs:
        print("Processing %s ..." % (item.name))
        if use_factor:
            item.data = mrc.downsample_with_factor(item.data,
                                                   factor=para1,
                                                   shape=para2)
        else:
            item.data = mrc.downsample_with_size(item.data,
                                                 size1=para1,
                                                 size2=para2)
    return inputs