Example #1
def plot_segmentation_samples(output_folder, plot_well=True):
    samples = []
    for root in PRED_SAVE_DIRs:
        all_pred_names = pickle.load(open(os.path.join(root, 'cla.pkl'),
                                          'rb'))['pred_names']
        plot_samples = []
        if plot_well:
            wells = list(set([get_well(n) for n in all_pred_names]))
            plot_wells = np.random.choice(np.arange(len(wells)), (10, ),
                                          replace=False)
            plot_wells = [wells[i] for i in plot_wells]
            plot_samples = [
                n for n in all_pred_names if get_well(n) in plot_wells
            ]
        else:
            plot_samples = np.random.choice(np.arange(len(all_pred_names)),
                                            (100, ),
                                            replace=False)
            plot_samples = [all_pred_names[i] for i in plot_samples]

        fs = [f for f in os.listdir(root) if f.startswith('seg_')]
        for f in sorted(fs, key=lambda x: int(x.split('_')[1].split('.')[0])):
            dat = pickle.load(open(os.path.join(root, f), 'rb'))
            seg_preds = dat['seg_preds']
            seg_trues = dat['seg_trues']
            seg_ws = dat['seg_ws']
            pred_names = dat['pred_names']
            for s_pred, s_true, s_w, name in zip(np.concatenate(seg_preds, 0),
                                                 np.concatenate(seg_trues, 0),
                                                 np.concatenate(seg_ws, 0),
                                                 pred_names):

                if name in plot_samples and not np.allclose(s_w, 0):
                    samples.append((s_pred, s_true, s_w, name))

    samples = sorted(samples, key=lambda x: get_well(x[3]))

    os.makedirs(output_folder, exist_ok=True)
    for i, (s_pred, s_true, s_w, name) in enumerate(samples):
        if isinstance(name, list) or isinstance(name, tuple):
            name = name[0]
        _s_pred = s_pred[s_w > 0]
        _s_true = s_true[s_w > 0]
        pr = pearsonr(_s_pred, _s_true)[0]

        title = str(get_identifier(name)) + '\n' + 'pearson-r: %.3f' % pr
        file_name = str(get_well(name)) + '__D%s.png' % get_identifier(name)[2]

        plt.clf()
        plt.subplot(2, 1, 1)
        plt.imshow(s_pred, vmin=0., vmax=3.)
        plt.axis('off')
        plt.title(title)
        plt.subplot(2, 1, 2)
        plt.imshow(s_true, vmin=0., vmax=3.)
        plt.axis('off')
        plt.savefig(os.path.join(output_folder, file_name),
                    dpi=300,
                    bbox_inches='tight')
Example #2
 def pair_identifier(p):
     id_from = get_identifier(self.names[p[0]])
     id_to = get_identifier(self.names[p[1]])
     assert id_from[:2] == id_to[:2]
     assert id_from[3:] == id_to[3:]
     pair_id = tuple([
         id_from[0], id_from[1],
         int(id_from[2]),
         int(id_to[2]), id_from[3], id_from[4]
     ])
     return pair_id
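pair_identifier relies on get_identifier returning a tuple whose third field (index 2) is the acquisition day; the two samples of a pair must agree on every other field, and the resulting pair id keeps both days. A minimal sketch of that assumption (the field names and values below are hypothetical):

# Assumed field layout: (experiment, plate, day, well, position)
id_from = ('ex1', 'plateA', '4', 'B3', '5')
id_to = ('ex1', 'plateA', '6', 'B3', '5')
pair_id = (id_from[0], id_from[1], int(id_from[2]), int(id_to[2]),
           id_from[3], id_from[4])
# -> ('ex1', 'plateA', 4, 6, 'B3', '5'): the same position imaged on day 4 and day 6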
Example #3
def plot_sample_labels(pairs,
                       save_dir='.',
                       raw_label_preprocess=lambda x: x,
                       linear_align=False):
    all_views = list(
        set(get_identifier(p[0])[3:] for p in pairs if p[0] is not None))
    selected_views = set([
        all_views[i]
        for i in np.random.choice(np.arange(len(all_views)), (20, ),
                                  replace=False)
    ])

    data = {}
    for view in selected_views:
        view_pairs = [
            p for p in pairs if p[0] is not None and p[1] is not None
            and get_identifier(p[0])[3:] == view
        ]
        print(view)
        for p in view_pairs:
            identifier = get_identifier(p[0])
            day = str(identifier[2])
            name = '_'.join(identifier)
            save_path = os.path.join(save_dir, day, name)
            os.makedirs(os.path.join(save_dir, day), exist_ok=True)

            pair_dat = load_image_pair(p)
            pair_dat = [pair_dat[0], raw_label_preprocess(pair_dat[1])]
            data[identifier] = pair_dat

            plt.clf()
            plt.imshow(pair_dat[1].astype(float))
            plt.savefig(save_path + '_fl.png')

            position_code = identifier[-1]
            if linear_align and position_code in ['1', '3', '7', '9']:
                mask = generate_mask(pair_dat)
            else:
                mask = np.ones_like(pair_dat[0])

            discrete_y = generate_fluorescence_labels(pair_dat, mask)
            if discrete_y is None:
                print("ERROR in labeling %s" % name)
                continue

            plt.clf()
            plt.imshow(discrete_y.astype(float), vmin=0, vmax=2)
            plt.savefig(save_path + '_fl_discrete.png')
    with open(os.path.join(save_dir, "data.pkl"), "wb") as f:
        pickle.dump(data, f)
    return
Example #4
def check_valid_for_0_to_inf_training(i):
    try:
        X, y, w, name = base_dataset.load_ind(i)
    except Exception as e:
        print(e)
        print("ISSUE %d" % i)
        return False, False

    if X is None:
        return False, False

    source_flag = True
    target_flag = True

    # "Source data from day 3 to 12, Target data from day 8 onwards"
    if int(get_identifier(name)[2]) < 8:
        target_flag = False
    if (y is None) or \
       (w is None) or \
       (base_dataset.classify_y[i] is None) or \
       (base_dataset.classify_w[i] is None):
        target_flag = False
    if np.all(w == 0) or (base_dataset.classify_w[i] == 0):
        target_flag = False
    return source_flag, target_flag
Example #5
def filter_for_0_predictor(batch):
    X, labels, batch_names = batch
    inds = []
    for i, name in enumerate(batch_names):
        if isinstance(name, list) or isinstance(name, tuple):
            name = name[0]
        if 7 <= int(get_identifier(name)[2]) <= 20:
            inds.append(i)
    return np.array(inds)
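filter_for_0_predictor returns the positions of the batch entries whose acquisition day falls between 7 and 20 (inclusive). A usage sketch, assuming the batch comes from a generator yielding (X, labels, batch_names):

X, labels, batch_names = batch            # hypothetical batch
keep = filter_for_0_predictor(batch)      # indices of samples imaged on days 7-20
X_kept, labels_kept = X[keep], labels[keep]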
Example #6
def PREPROCESS_FILTER(pair, well_setting='96well-3'):
    # Remove samples without phase contrast
    if pair[0] is None:
        return False
    # Remove samples with inconsistent id
    if pair[1] is not None and get_identifier(pair[0]) != get_identifier(
            pair[1]):
        return False
    # Remove corner samples
    if well_setting == '6well-15':
        if get_identifier(pair[0])[-1] in \
            ['1', '2', '16', '14', '15', '30', '196', '211', '212', '210', '224', '225']:
            return False
    elif well_setting == '6well-14':
        if get_identifier(pair[0])[-1] in \
            ['1', '2', '15', '13', '14', '28', '169', '183', '184', '182', '195', '196']:
            return False
    elif well_setting == '96well-3':
        if get_identifier(pair[0])[-1] in \
            ['1', '3', '7', '9']:
            return False
    return True
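Once the well layout is bound, PREPROCESS_FILTER has the single-argument signature that the preprocess_filter hook of preprocess (Example #16) expects, so it can be passed via functools.partial. A minimal sketch with hypothetical (phase-contrast, fluorescence) file pairs:

from functools import partial

# Bind the well layout so the filter matches preprocess()'s single-argument hook
well_filter = partial(PREPROCESS_FILTER, well_setting='96well-3')
kept_pairs = [p for p in pairs if well_filter(p)]   # pairs: hypothetical (pc, fl) tuples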
Example #7
    def get_all_pairs(self, time_interval=[1, 3]):
        valid_pcs = {}
        valid_fls = {}
        for k in sorted(self.names.keys()):
            x, y, w, n = self.load_ind(k)
            assert isinstance(n, str)
            if x is not None:
                valid_pcs[k] = get_identifier(n)
            if y is not None:
                valid_fls[k] = get_identifier(n)

        fls_reverse_mapping = {v: k for k, v in valid_fls.items()}

        valid_pairs = []
        for ind_i in sorted(valid_pcs):
            d = valid_pcs[ind_i]
            if d[2] == 'unknown':
                continue
            for t in range(time_interval[0], time_interval[1] + 1):
                new_d = (d[0], d[1], str(int(d[2]) + t), d[3], d[4])
                if new_d in fls_reverse_mapping:
                    ind_j = fls_reverse_mapping[new_d]
                    valid_pairs.append((ind_i, ind_j))
        return valid_pairs
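get_all_pairs links each phase-contrast sample to fluorescence samples of the same well acquired t days later, for every t in the inclusive time_interval. With the field layout assumed in the note after Example #2, time_interval=[1, 3] and a phase-contrast identifier d, the lookups against fls_reverse_mapping would be:

# d = ('ex1', 'plateA', '4', 'B3', '5')   (hypothetical)
# candidates: ('ex1', 'plateA', '5', 'B3', '5'),
#             ('ex1', 'plateA', '6', 'B3', '5'),
#             ('ex1', 'plateA', '7', 'B3', '5')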
Example #8
def check_valid_for_0_to_0_training(i):
    try:
        X, y, w, name = base_dataset.load_ind(i)
    except Exception as e:
        print(e)
        print("ISSUE %d" % i)
        return False
    # "Use data after day 7 onwards"
    if int(get_identifier(name)[2]) < 7:
        return False
    if (X is None) or \
       (y is None) or \
       (w is None) or \
       (base_dataset.classify_y[i] is None) or \
       (base_dataset.classify_w[i] is None):
        return False
    if np.all(w == 0) or (base_dataset.classify_w[i] == 0):
        return False
    return True
Example #9
def get_well(name):
    if isinstance(name, list) or isinstance(name, tuple):
        name = name[0]
    return get_identifier(name)[:2] + get_identifier(name)[-2:]
Example #10
def get_day(n):
    return int(get_identifier(n)[2])
Example #11
def day_info(name):
    if isinstance(name, tuple):
        name = name[0]
    return int(get_identifier(name)[2])
Example #12
def well_info(name):
    if isinstance(name, tuple):
        name = name[0]
    return get_identifier(name)[:2] + get_identifier(name)[3:]
Example #13
def well_info(name):
    return get_identifier(name)[:2] + get_identifier(name)[3:]
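The small helpers in Examples #9-#13 all slice the tuple returned by get_identifier: index 2 is read as the acquisition day, and dropping it yields a well-level key (for a 5-field identifier, [:2] + [-2:] and [:2] + [3:] select the same fields). A hedged sketch of that convention, with hypothetical field names and raw name:

# Assumed layout: (experiment, plate, day, well, position)
name = 'ex1_plateA_D4_B3_5'            # hypothetical raw name
identifier = get_identifier(name)      # assumed -> ('ex1', 'plateA', '4', 'B3', '5')
well_info(name)                        # -> ('ex1', 'plateA', 'B3', '5'): day dropped
day_info(name)                         # -> 4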
Example #14
def get_pairs(inds, label_ind, startday_range=(4, 12)):
    if int(id_mapping[label_ind][2]) < 10:
        return []
    start_inds = [ind for ind in inds if \
        int(id_mapping[ind][2]) >= startday_range[0] and \
        int(id_mapping[ind][2]) <= startday_range[1] and \
        int(id_mapping[ind][2]) <= int(id_mapping[label_ind][2]) - 3]
    return [(i, label_ind) for i in start_inds]


# Validity of samples
flags = {
    i: check_valid_for_0_to_inf_training(i)
    for i in base_dataset.selected_inds
}
id_mapping = {i: get_identifier(base_dataset.names[i]) for i in flags}

# Validity of wells
valid_wells = sorted(
    set([
        get_identifier(base_dataset.names[i])[:2] +
        get_identifier(base_dataset.names[i])[3:] for i in flags if flags[i][1]
    ]))

quest_pairs = []
extra_pairs = []
for well in valid_wells:
    related_inds = [
        i for i in flags
        if flags[i][0] and id_mapping[i][:2] + id_mapping[i][3:] == well
    ]
Example #15
def extract_samples_for_inspection(pairs,
                                   inter_dir,
                                   image_output_dir,
                                   seed=123):
    if seed is not None:
        np.random.seed(seed)
    if not os.path.exists(image_output_dir):
        os.makedirs(image_output_dir, exist_ok=True)
    raw_id_to_f_mapping = {get_identifier(p[0]): p for p in pairs}

    fs = os.listdir(inter_dir)
    # Check existence of identifier file
    assert 'names.pkl' in fs
    names = pickle.load(open(os.path.join(inter_dir, 'names.pkl'), 'rb'))
    for i, n in names.items():
        assert get_identifier(n) in raw_id_to_f_mapping

    # Check phase contrast files
    phase_contrast_files = [
        f for f in fs if f.startswith('X_') and f.endswith('.pkl')
    ]
    for i in range(len(phase_contrast_files)):
        assert 'X_%d.pkl' % i in fs

    # Sample phase contrast image
    os.makedirs(os.path.join(image_output_dir, "phase_contrast"),
                exist_ok=True)
    random_inds = np.random.choice(list(names.keys()), (50, ), replace=False)
    for ind in random_inds:
        file_ind = ind // 100
        identifier = get_identifier(names[ind])
        try:
            processed_img = pickle.load(
                open(os.path.join(inter_dir, 'X_%d.pkl' % file_ind),
                     'rb'))[ind]
            raw_img = load_image(raw_id_to_f_mapping[identifier][0])
            out_path = os.path.join(image_output_dir, "phase_contrast",
                                    "%s.png" % '_'.join(identifier))
            save_multi_panel_fig([raw_img, processed_img], out_path)

        except Exception:
            print("Error saving sample %s" % '_'.join(identifier))

    # try:
    #     # Check discrete segmentation annotations
    #     assert "classify_discrete_labels.pkl" in fs
    #     for i in range(len(phase_contrast_files)):
    #         assert 'segment_discrete_y_%d.pkl' % i in fs
    #         assert 'segment_discrete_w_%d.pkl' % i in fs

    #     classify_discrete_labels = pickle.load(open(os.path.join(inter_dir, "classify_discrete_labels.pkl"), 'rb'))
    #     inds_by_class = {}
    #     for k in classify_discrete_labels:
    #         if classify_discrete_labels[k][0] is None or classify_discrete_labels[k][1] == 0:
    #             continue
    #         label = classify_discrete_labels[k][0]
    #         if not label in inds_by_class:
    #             inds_by_class[label] = []
    #         inds_by_class[label].append(k)

    #     # Sample discrete fl segmentation (by class)
    #     for cl in inds_by_class:
    #         os.makedirs(os.path.join(image_output_dir, "discrete_segmentation_class_%s" % str(cl)), exist_ok=True)
    #         if len(inds_by_class[cl]) > 20:
    #             random_inds = np.random.choice(list(inds_by_class[cl]), (20,), replace=False)
    #         else:
    #             random_inds = inds_by_class[cl]
    #         for ind in random_inds:
    #             file_ind = ind // 100
    #             identifier = get_identifier(names[ind])
    #             try:
    #                 raw_pc = load_image(raw_id_to_f_mapping[identifier][0])
    #                 raw_fl = load_image(raw_id_to_f_mapping[identifier][1])
    #                 processed_pc = pickle.load(open(os.path.join(inter_dir, 'X_%d.pkl' % file_ind), 'rb'))[ind]
    #                 processed_fl_y = pickle.load(open(os.path.join(inter_dir, 'segment_discrete_y_%d.pkl' % file_ind), 'rb'))[ind]
    #                 processed_fl_w = pickle.load(open(os.path.join(inter_dir, 'segment_discrete_w_%d.pkl' % file_ind), 'rb'))[ind]
    #                 out_path = os.path.join(image_output_dir,
    #                                         "discrete_segmentation_class_%s" % str(cl),
    #                                         "%s.png" % '_'.join(identifier))
    #                 save_multi_panel_fig([raw_pc,
    #                                       processed_pc,
    #                                       raw_fl,
    #                                       None,
    #                                       processed_fl_y,
    #                                       processed_fl_w], out_path)
    #             except Exception:
    #                 print("Error saving fl(discrete) sample %s" % '_'.join(identifier))
    # except Exception:
    #     print("Issue locating discrete segmentation files")

    try:
        # Check continuous segmentation annotations
        assert "classify_continuous_labels.pkl" in fs
        for i in range(len(phase_contrast_files)):
            assert 'segment_continuous_y_%d.pkl' % i in fs
            assert 'segment_continuous_w_%d.pkl' % i in fs

        classify_continuous_labels = pickle.load(
            open(os.path.join(inter_dir, "classify_continuous_labels.pkl"),
                 'rb'))
        inds_by_class = {}
        for k in classify_continuous_labels:
            if (classify_continuous_labels[k][0] is None
                    or classify_continuous_labels[k][1] == 0):
                continue
            label = np.argmax(classify_continuous_labels[k][0])
            if label not in inds_by_class:
                inds_by_class[label] = []
            inds_by_class[label].append(k)

        # Sample continuous fl segmentation (by class)
        for cl in inds_by_class:
            os.makedirs(os.path.join(
                image_output_dir,
                "continuous_segmentation_class_%s" % str(cl)),
                        exist_ok=True)
            if len(inds_by_class[cl]) > 20:
                random_inds = np.random.choice(list(inds_by_class[cl]), (20, ),
                                               replace=False)
            else:
                random_inds = inds_by_class[cl]
            for ind in random_inds:
                file_ind = ind // 100
                identifier = get_identifier(names[ind])
                try:
                    raw_pc = load_image(raw_id_to_f_mapping[identifier][0])
                    raw_fl = load_image(raw_id_to_f_mapping[identifier][1])
                    processed_pc = pickle.load(
                        open(os.path.join(inter_dir, 'X_%d.pkl' % file_ind),
                             'rb'))[ind]
                    processed_fl_y = pickle.load(
                        open(
                            os.path.join(
                                inter_dir,
                                'segment_continuous_y_%d.pkl' % file_ind),
                            'rb'))[ind]
                    # Collapse the 4-channel soft labels into a single intensity map
                    processed_fl_y = (
                        processed_fl_y *
                        np.array([0., 1. / 3, 2. / 3, 1.]).reshape(
                            (1, 1, 4))).sum(2)
                    processed_fl_w = pickle.load(
                        open(
                            os.path.join(
                                inter_dir,
                                'segment_continuous_w_%d.pkl' % file_ind),
                            'rb'))[ind]
                    out_path = os.path.join(
                        image_output_dir,
                        "continuous_segmentation_class_%s" % str(cl),
                        "%s.png" % '_'.join(identifier))
                    save_multi_panel_fig([
                        raw_pc, processed_pc, raw_fl, None, processed_fl_y,
                        processed_fl_w
                    ], out_path)
                except Exception:
                    print("Error saving fl(continuous) sample %s" %
                          '_'.join(identifier))
    except Exception:
        print("Issue locating continuous segmentation files")
Example #16
def preprocess(
        pairs,
        output_path=None,
        preprocess_filter=lambda x: True,
        target_size=(384, 288),
        labels=['discrete', 'continuous'],
        raw_label_preprocess=lambda x: x,
        nonneg_thr=65535,
        well_setting='96well-3',  #'6well-15', '6well-14' or '96well-3'
        linear_align=False,
        shuffle=True,
        seed=None):
    if seed is not None:
        np.random.seed(seed)

    # Sanity check
    pairs = [p for p in pairs if p[0] is not None and preprocess_filter(p)]
    for p in pairs:
        if p[1] is not None:
            assert get_identifier(p[0]) == get_identifier(p[1])

    # Sort
    pairs = sorted(pairs)
    if shuffle:
        np.random.shuffle(pairs)

    # Featurize data
    cv2_shape = target_size
    # Note that cv2 and numpy have reversed axis ordering
    np_shape = (target_size[1], target_size[0], -1)
    names = {}
    Xs = {}
    segment_discrete_ys = {}
    segment_discrete_ws = {}
    segment_continuous_ys = {}
    segment_continuous_ws = {}
    classify_discrete_labels = {}
    classify_continuous_labels = {}
    file_ind = 0
    for ind, pair in enumerate(pairs):
        identifier = get_identifier(pair[0])
        names[ind] = pair[0]
        try:
            # Input feature (phase contrast image)
            pair_dat = load_image_pair(pair)
            position_code = identifier[-1]
            if well_setting == '96well-3' and position_code in [
                    '1', '3', '7', '9'
            ] and pair_dat[1] is not None:
                mask = generate_mask(pair_dat)
            else:
                mask = np.ones_like(pair_dat[0])
            X = adjust_contrast(
                pair_dat,
                mask,
                position_code,
                linear_align=(linear_align and well_setting == '96well-3'))
            X = cv2.resize(X, cv2_shape)

            # Segment weights
            w = generate_weight(mask,
                                position_code,
                                linear_align=(well_setting == '96well-3'))
            w = cv2.resize(w, cv2_shape)

            # Segment labels (binarized fluorescence, discrete labels)
            pair_dat = [pair_dat[0], raw_label_preprocess(pair_dat[1])]
            Xs[ind] = X.reshape(np_shape).astype(float)

        except Exception as e:
            print("ERROR in loading pair %s" % str(identifier))
            print(e)
            Xs[ind] = None

        if pair_dat[1] is not None and 'discrete' in labels:
            try:
                y, discrete_w = generate_discrete_labels(pair_dat,
                                                         mask,
                                                         cv2_shape,
                                                         w,
                                                         nonneg_thr=nonneg_thr)
                segment_discrete_ys[ind] = y.reshape(np_shape).astype(int)
                segment_discrete_ws[ind] = discrete_w.reshape(np_shape).astype(
                    float)
            except Exception as e:
                print("ERROR in generating fluorescence label %s" %
                      str(identifier))
                print(e)
                segment_discrete_ys[ind] = None
                segment_discrete_ws[ind] = None
        else:
            segment_discrete_ys[ind] = None
            segment_discrete_ws[ind] = None

        # Segment labels (continuous fluorescence in 4 classes)
        if pair_dat[1] is not None and 'continuous' in labels:
            try:
                y, continuous_w = generate_continuous_labels(
                    pair_dat, mask, cv2_shape, w, nonneg_thr=nonneg_thr)
                segment_continuous_ys[ind] = y.reshape(np_shape).astype(float)
                segment_continuous_ws[ind] = continuous_w.reshape(
                    np_shape).astype(float)

                classify_continuous_y = segment_continuous_ys[ind].sum((0, 1))
                classify_continuous_y = classify_continuous_y / (
                    1e-5 + np.sum(classify_continuous_y))
            except Exception as e:
                print("ERROR in generating fluorescence label %s" %
                      str(identifier))
                print(e)
                segment_continuous_ys[ind] = None
                segment_continuous_ws[ind] = None
                classify_continuous_y = None
        else:
            segment_continuous_ys[ind] = None
            segment_continuous_ws[ind] = None
            classify_continuous_y = None

        # Classify labels
        classify_discrete_labels[ind] = binarized_fluorescence_label(
            segment_discrete_ys[ind], segment_discrete_ws[ind])

        # Continuous label (4-class) will be dependent on fluorescence intensity level
        thrs = np.array([0., 0.25, 0.35, 0.65])
        _classify_continuous_w = classify_discrete_labels[ind][1]
        if (classify_discrete_labels[ind][0] is None
                or _classify_continuous_w == 0):
            _classify_continuous_y = None
        elif classify_discrete_labels[ind][0] == 0:
            _classify_continuous_y = np.array([1., 0., 0., 0.])
        else:
            assert classify_continuous_y is not None
            _fl_intensity_lev = (classify_continuous_y *
                                 np.array([0., 0.5, 1., 3.])).sum()
            _classify_continuous_y = np.exp(-np.abs(thrs - _fl_intensity_lev) /
                                            0.2)
            _classify_continuous_y = (_classify_continuous_y /
                                      _classify_continuous_y.sum())
        classify_continuous_labels[ind] = (_classify_continuous_y,
                                           _classify_continuous_w)

        # Save data
        if output_path is not None and ((ind % 100 == 99) or
                                        (ind == len(pairs) - 1)):
            assert len(Xs) <= 100
            print("Writing file %d" % file_ind)
            with open(os.path.join(output_path, 'names.pkl'), 'wb') as f:
                pickle.dump(names, f)
            with open(os.path.join(output_path, 'X_%d.pkl' % file_ind),
                      'wb') as f:
                pickle.dump(Xs, f)
            if 'discrete' in labels:
                with open(
                        os.path.join(output_path,
                                     'segment_discrete_y_%d.pkl' % file_ind),
                        'wb') as f:
                    pickle.dump(segment_discrete_ys, f)
                with open(
                        os.path.join(output_path,
                                     'segment_discrete_w_%d.pkl' % file_ind),
                        'wb') as f:
                    pickle.dump(segment_discrete_ws, f)
                with open(
                        os.path.join(output_path,
                                     'classify_discrete_labels.pkl'),
                        'wb') as f:
                    pickle.dump(classify_discrete_labels, f)
            if 'continuous' in labels:
                with open(
                        os.path.join(output_path,
                                     'segment_continuous_y_%d.pkl' % file_ind),
                        'wb') as f:
                    pickle.dump(segment_continuous_ys, f)
                with open(
                        os.path.join(output_path,
                                     'segment_continuous_w_%d.pkl' % file_ind),
                        'wb') as f:
                    pickle.dump(segment_continuous_ws, f)
                with open(
                        os.path.join(output_path,
                                     'classify_continuous_labels.pkl'),
                        'wb') as f:
                    pickle.dump(classify_continuous_labels, f)
            file_ind += 1
            Xs = {}
            segment_discrete_ys = {}
            segment_discrete_ws = {}
            segment_continuous_ys = {}
            segment_continuous_ws = {}
    return file_ind
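preprocess flushes its buffers every 100 samples, so the features of global index ind end up in X_<file_ind>.pkl with file_ind = ind // 100, while names.pkl accumulates all identifiers; this matches the shard lookup in Example #15. A minimal loading sketch under that assumption (load_sample is a hypothetical helper, not part of the original code):

import os
import pickle

def load_sample(inter_dir, ind):
    # Shards hold up to 100 samples each, keyed by the global sample index
    with open(os.path.join(inter_dir, 'names.pkl'), 'rb') as f:
        name = pickle.load(f)[ind]
    with open(os.path.join(inter_dir, 'X_%d.pkl' % (ind // 100)), 'rb') as f:
        X = pickle.load(f)[ind]
    return name, X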
Example #17
    'segment_label_type': 'discrete',
    'n_classify_classes': 2,
    'classify_class_weights': [0.5, 0.15],
    'classify_label_type': 'discrete',
}

base_gen = PairGenerator(
    name_file,
    X_filenames,
    segment_y_files=y_filenames,
    segment_w_files=w_filenames,
    classify_label_file=label_file,
    augment=True,
    **kwargs)

get_ex = lambda x: get_identifier(x)[:2]
all_exs = set(get_ex(n[0]) for i, n in base_gen.names.items())
for ex in all_exs:
    if ex not in VALID_DIRS or ex not in MODEL_DIRS:
        continue
    save_dir = VALID_DIRS[ex]
    model_dir = MODEL_DIRS[ex]
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)
    
    # Setup train/valid datasets
    train_inds = [i for i, n in base_gen.names.items() if get_ex(n[0]) != ex]
    valid_inds = [i for i, n in base_gen.names.items() if get_ex(n[0]) == ex]
    assert len(train_inds) + len(valid_inds) == base_gen.N
    assert len(set(train_inds) & set(valid_inds)) == 0
    print("Valid with %s: %d / %d" % (str(ex), len(train_inds), len(valid_inds)))