def nn_features(study, nets, model_name, slice_type, frames_preproc):
    nns = nets[slice_type]
    names = np.array([
        '%s_%s_bucket_%s' % (model_name, slice_type, i)
        for i in range(len(nns))
    ])
    buckets = nn_pipeline_multi.get_buckets(slice_type)
    assert len(buckets) - 1 == len(
        nns), 'number of buckets and nns does not match'
    features = np.zeros(len(nns))
    if slice_type not in study:
        return features, names
    for bucket_id in range(len(nns)):
        keep_slice = nn_pipeline_multi.keep_slice_location(bucket_id, buckets)
        slices = study[slice_type].keys()
        pred, num = 0.0, 0
        for slice_id in slices:
            if keep_slice(study[slice_type][slice_id]):
                frames = common.slice_to_numpy(study[slice_type][slice_id])
                frames = frames_preproc(frames)
                pred += validate.net_output(nns[bucket_id], frames)
                num += 1
        if num > 0:
            pred /= num
        features[bucket_id] = pred
    return features, names
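
A minimal usage sketch for nn_features (not from the original source), assuming a nets dict keyed by slice type with one trained net per location bucket; load_net, sax_model_paths, the model name and the study variable are hypothetical placeholders, while create_lmdb.default_preproc is the preprocessing used elsewhere in this code:

# Hypothetical sketch: per-bucket features for one study.
# `nets` maps slice type -> list of trained nets, one per location bucket.
nets = {'sax': [load_net(p) for p in sax_model_paths]}  # load_net / sax_model_paths are assumed helpers
features, names = nn_features(study, nets, 'convnet_v1', 'sax',
                              frames_preproc=create_lmdb.default_preproc)
# features[i] is the mean net output over the sax slices falling in bucket i;
# buckets with no matching slices stay at 0.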
Example #2
def create_db_unilabel(db_folder,
                       processed_data_folder,
                       frames_preproc=None,
                       trans_generator=lambda a, b: [None],
                       label_key='systole',
                       slice_type='sax',
                       keep_slice=lambda study: True):
    """Creates LMDB for data and labels, where labes is a single value.

    db_folder: the LMDB destination for the data.
    db_label_folder: the LMDB destination for the labels.
    processed_data_folder: the directory of .pkl files written by data_pipeline.py.
    frames_preproc: any chosen preprocessing function for the frames.
    trans_generator: a function that takes in study and slice id and returns a set
        of transformations to include in the dataset.
    label_key: 'systole' or 'diastole'.
    keep_slice: a function that decides if to keep a slice or not.

    """
    frames_preproc = frames_preproc or default_preproc

    if os.path.exists(db_folder):
        shutil.rmtree(db_folder)

    data_db = lmdb.open(db_folder, map_size=1e12)

    # txn is a transaction object
    with data_db.begin(write=True) as data_txn:
        counter = 0
        for study_file in os.listdir(processed_data_folder):
            assert study_file.endswith(
                '.pkl'), 'file %s has wrong extension' % study_file
            with open(os.path.join(processed_data_folder, study_file),
                      'rb') as f:
                study = pickle.load(f)
                study_id = study['study']
                if slice_type not in study:
                    print >> sys.stderr, 'study_id %s does not have %s' % (
                        study_id, slice_type)
                    continue
                slices = study[slice_type].keys()
                for slice_id in slices:
                    if not keep_slice(study[slice_type][slice_id]):
                        continue
                    for trans in trans_generator(study_id, slice_id):
                        volume = study[label_key]
                        volume, frames = apply_transform(
                            volume,
                            common.slice_to_numpy(study[slice_type][slice_id]),
                            trans)
                        frames = frames_preproc(frames)
                        label = int(volume)
                        data = caffe.io.array_to_datum(frames, label)
                        str_id = '{:06d}'.format(counter)
                        data_txn.put(str_id, data.SerializeToString())
                        counter += 1
        print '%s, added %s slices' % (db_folder, counter)

    data_db.close()
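
A hedged example of how create_db_unilabel might be called; the identity transform list and the keep-everything predicate mirror the function's defaults, and the paths are placeholders:

# Hypothetical call: one LMDB of sax slices, each labelled with the systole volume.
create_db_unilabel('data/train_systole_lmdb',
                   'data/processed',  # .pkl files written by data_pipeline.py
                   frames_preproc=default_preproc,
                   trans_generator=lambda study_id, slice_id: [None],  # identity transform only
                   label_key='systole',
                   slice_type='sax',
                   keep_slice=lambda s: True)  # keep every slice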
def create_data(x_file,
                y_file,
                processed_data_folder,
                frames_preproc=None,
                slice_type='sax',
                one_slice_per_study=False,
                ids_file=None):
    """Creates numpy arrays for use by Keras.

    x_file: the destination file for the x numpy array.
    y_file: the destination file for the y (systole, diastole) numpy array.
    processed_data_folder: the directory of .pkl files written by data_pipeline.py.
    frames_preproc: any chosen preprocessing function for the frames.
    one_slice_per_study: only take one slice from each study.
    ids_file: if given, the destination file for the study ids.

    """
    frames_preproc = frames_preproc or create_lmdb.default_preproc
    X = []
    y = []
    ids = []

    for study_file in os.listdir(processed_data_folder):
        assert study_file.endswith(
            '.pkl'), 'file %s has wrong extension' % study_file
        with open(os.path.join(processed_data_folder, study_file), 'rb') as f:
            study = pickle.load(f)
            study_id = study['study']
            if slice_type not in study:
                print >> sys.stderr, 'study_id %s does not have %s' % (
                    study_id, slice_type)
                continue
            slices = study[slice_type].keys()
            if one_slice_per_study:
                slices = slices[:1]
            for slice_ in slices:
                systole, diastole = study['systole'], study['diastole']
                frames = common.slice_to_numpy(study[slice_type][slice_])
                frames = frames_preproc(frames)
                X.append(frames)
                y.append([systole, diastole])
                ids.append(study_id)
                if len(X) % 100 == 0:
                    print 'added', len(X), 'data points to db'

    X = np.array(X)
    y = np.array(y)
    ids = np.array(ids)
    np.save(x_file, X)
    np.save(y_file, y)
    if ids_file:
        np.save(ids_file, ids)
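
A short sketch of calling create_data to build Keras-ready arrays; the .npy paths are placeholders:

# Hypothetical call: X holds one preprocessed frame stack per kept slice,
# y has shape (n_slices, 2) with a (systole, diastole) pair per row.
create_data('data/X_train.npy', 'data/y_train.npy', 'data/processed',
            frames_preproc=create_lmdb.default_preproc,
            slice_type='sax',
            one_slice_per_study=False,
            ids_file='data/ids_train.npy')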
Example #5
def write_nn_submission(out_file,
                        processed_data_folder,
                        sys_net,
                        dia_net,
                        default_sys=None,
                        default_dia=None,
                        frames_preproc=None):

    frames_preproc = frames_preproc or create_lmdb.default_preproc
    default_sys = np.zeros(600) if default_sys is None else default_sys
    default_dia = np.zeros(600) if default_dia is None else default_dia

    with open(out_file, 'w') as out:
        submission = csv.writer(out, lineterminator='\n')
        submission.writerow(['Id'] + ['P%d' % i for i in xrange(600)])
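        # Studies 501-700 form the submission set; one Diastole row and one
        # Systole row is written per study.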
        for i in xrange(501, 701):
            print i
            study_file = os.path.join(processed_data_folder, 'study%d.pkl' % i)
            assert study_file.endswith('.pkl'), \
                'file %s has wrong extension' % study_file
            with open(os.path.join(processed_data_folder, study_file),
                      'rb') as f:
                study = pickle.load(f)
                study_id = study['study']
                assert study_id == i, i

                if 'sax' not in study:
                    print >> sys.stderr, 'no slices for study %d' % study_id
                    submission.writerow(['%d_Diastole' % study_id] +
                                        submission_helper(default_dia))
                    submission.writerow(['%d_Systole' % study_id] +
                                        submission_helper(default_sys))
                    continue

                study_data = study['sax']
                for label_key, net in zip(['Diastole', 'Systole'],
                                          [dia_net, sys_net]):
                    net_predictions = np.array([
                        net_output(
                            net,
                            frames_preproc(
                                common.slice_to_numpy(study_data[slice_])))
                        for slice_ in study_data
                    ])
                    final_predictions = (np.mean(net_predictions, axis=0)
                                         if len(net_predictions.shape) > 1 else
                                         net_predictions)
                    submission.writerow(['%d_%s' % (study_id, label_key)] +
                                        submission_helper(final_predictions))
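
A sketch of driving write_nn_submission, assuming sys_net and dia_net are already-trained nets whose net_output returns a 600-long cumulative prediction; train_labels_cdf and the paths are hypothetical placeholders:

# Hypothetical call: fall back to a mean training CDF when a study has no sax slices.
default_cdf = np.mean(train_labels_cdf, axis=0)  # train_labels_cdf: assumed (n, 600) array of training CDFs
write_nn_submission('submission.csv', 'data/processed_validate',
                    sys_net=sys_net, dia_net=dia_net,
                    default_sys=default_cdf, default_dia=default_cdf,
                    frames_preproc=create_lmdb.default_preproc)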
Example #8
def create_db(db_folder,
              db_label_folder,
              processed_data_folder,
              frames_preproc=None,
              trans_generator=lambda a, b: [None],
              label_key='systole',
              slice_type='sax',
              num_total_data=None,
              one_slice_per_study=False):
    """Creates LMDB for data and labels.

    db_folder: the LMDB destination for the data.
    db_label_folder: the LMDB destination for the labels.
    processed_data_folder: the directory of .pkl files written by data_pipeline.py.
    frames_preproc: any chosen preprocessing function for the frames.
    trans_generator: a function that takes in study and slice id and returns a set
        of transformations to include in the dataset.
    label_key: 'systole' or 'diastole'.
    num_total_data: the total number of records, if known; required to shuffle the data.
    one_slice_per_study: only take one slice from each study.

    """
    frames_preproc = frames_preproc or default_preproc

    for folder in [db_folder, db_label_folder]:
        if os.path.exists(folder):
            shutil.rmtree(folder)

    data_db = lmdb.open(db_folder, map_size=1e12)
    label_db = lmdb.open(db_label_folder, map_size=1e12)

    if num_total_data:
        shuffled_counters = range(num_total_data)
        random.shuffle(shuffled_counters)

    # txn is a transaction object
    with data_db.begin(write=True) as data_txn:
        with label_db.begin(write=True) as label_txn:
            counter = 0
            for study_file in os.listdir(processed_data_folder):
                assert study_file.endswith(
                    '.pkl'), 'file %s has wrong extension' % study_file
                with open(os.path.join(processed_data_folder, study_file),
                          'rb') as f:
                    study = pickle.load(f)
                    study_id = study['study']
                    if slice_type not in study:
                        print >> sys.stderr, 'study_id %s does not have %s' % (
                            study_id, slice_type)
                        continue
                    slices = study[slice_type].keys()
                    if one_slice_per_study:
                        slices = slices[:1]
                    for slice_ in slices:
                        for trans in trans_generator(study_id, slice_):
                            volume = study[label_key]
                            volume, frames = apply_transform(
                                volume,
                                common.slice_to_numpy(
                                    study[slice_type][slice_]), trans)
                            frames = frames_preproc(frames)
                            data = caffe.io.array_to_datum(frames)
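                            # Label: a 600-long cumulative indicator that is 1 at every
                            # threshold (0..599) greater than or equal to the volume.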
                            label = np.array(volume <= np.arange(600),
                                             dtype=np.uint8).reshape(
                                                 (1, 1, 600))
                            label = caffe.io.array_to_datum(label)
                            if num_total_data:
                                str_id = '{:06d}'.format(
                                    shuffled_counters[counter])
                            else:
                                str_id = '{:06d}'.format(counter)
                            data_txn.put(str_id, data.SerializeToString())
                            label_txn.put(str_id, label.SerializeToString())
                            counter += 1
                            if counter % 100 == 0:
                                print 'added', counter, 'data points to db'

    label_db.close()
    data_db.close()
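
A sketch of a create_db call that also writes the separate label LMDB; shuffling requires knowing the record count up front, so num_total_data is passed when known (all values here are placeholders):

# Hypothetical call: paired data/label LMDBs with records written in shuffled order.
create_db('data/train_diastole_lmdb', 'data/train_diastole_label_lmdb',
          'data/processed',
          frames_preproc=default_preproc,
          label_key='diastole',
          slice_type='sax',
          num_total_data=12000,  # assumed total number of (slice, transform) records
          one_slice_per_study=False)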