Example #1
# Assumed imports for this snippet. resample, uniform_sample,
# adjust_center_for_boundaries and extract_patch are project helpers
# defined elsewhere in the source repository.
import os
import random

import numpy as np
import scipy.ndimage as snd
import SimpleITK as sitk


def sample_patches(lid, imgpath, lblpath):
    image = sitk.ReadImage(imgpath)
    image = resample(image, (1.0, 1.0, 1.0), interpolator=sitk.sitkLinear)
    img_arr = sitk.GetArrayFromImage(image)
    img_arr = snd.zoom(img_arr, zoom=(0.5, 0.5, 0.5), order=1)
    img_arr = np.float32(np.clip(img_arr, -100, 400))
    img_arr = np.uint8(255 * (img_arr + 100) / 500)
    img_arr = np.pad(img_arr, ((100, 100), (100, 100), (100, 100)),
                     mode='constant')
        
    label = sitk.ReadImage(lblpath)
    label = resample(label, (1.0, 1.0, 1.0), interpolator=sitk.sitkNearestNeighbor)
    lbl_arr = sitk.GetArrayFromImage(label)
    lbl_arr[lbl_arr == 2] = 1
    lbl_arr = np.uint8(snd.zoom(lbl_arr, zoom=(0.5, 0.5, 0.5), order=0))
    lbl_arr_cp = lbl_arr.copy() + 1
    lbl_arr = np.pad(lbl_arr, ((100, 100), (100, 100), (100, 100)), mode='constant')
    lbl_arr_cp = np.pad(lbl_arr_cp, ((100, 100), (100, 100), (100, 100)), mode='constant')
    lbl_arr_cp -= 1
    
    class1_locs = uniform_sample(lbl_arr_cp == 0, 50)
    class2_locs = uniform_sample(lbl_arr_cp == 1, 50)
#     print(' class 1, class 2 :', len(class1_locs), len(class2_locs))
    locs = class1_locs[:5] + class2_locs[:45]
    random.shuffle(locs)
    
    patch_size, lbl_size = [116, 132, 132], [28, 44, 44]
    liver_pixel_count = {}
    for idx, l in enumerate(locs):
        l = adjust_center_for_boundaries(l, patch_size, img_arr.shape)
        img_patch = extract_patch(img_arr, l, patch_size)
        lbl_patch = extract_patch(lbl_arr, l, lbl_size)
        liver_pixel_count[idx] = np.sum(lbl_patch)
        save_dir = './data/train'
        inppname = 'img' + str(lid) + '_input' + str(idx) + '.npy'
        tgtpname = 'img' + str(lid) + '_label' + str(idx) + '.npy'
        np.save(os.path.join(save_dir, inppname), img_patch)
        np.save(os.path.join(save_dir, tgtpname), lbl_patch)
Example #2
def sample_patches(lid, imgpath, lblpath):
    '''
    Extract patches from an image/label pair and save them as .npy files.

    Parameters:
        imgpath - path of the image

        lblpath - path of the corresponding segmented image (label map)

        lid - identifier of the image, used to name the saved patch files

    '''

    # Image

    # Read the image; 'image' is a SimpleITK Image object
    image = sitk.ReadImage(imgpath)
    # Resample to an isotropic 1 mm voxel spacing
    image = resample(image, (1.0, 1.0, 1.0), interpolator=sitk.sitkLinear)
    # Get the voxel values as a numpy array
    img_arr = sitk.GetArrayFromImage(image)
    # Downsample the volume by a factor of 2 along each axis
    img_arr = snd.zoom(img_arr, zoom=(0.5, 0.5, 0.5), order=1)
    # Clip the intensities to [-100, 400], a soft-tissue HU window
    # commonly used for liver CT, and convert to float32
    img_arr = np.float32(np.clip(img_arr, -100, 400))
    # Rescale the clipped window [-100, 400] linearly onto [0, 255] uint8
    img_arr = np.uint8(255 * (img_arr + 100) / 500)
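    # Worked example of the rescale: -100 HU -> 0, 150 HU -> ~127,
    # 400 HU -> 255, so the 500-unit window fills the full uint8 range.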
    # Pad the volume so patches near the boundary can still be extracted
    img_arr = np.pad(img_arr, ((100, 100), (100, 100), (100, 100)),
                     mode='constant')

    # Label
    label = sitk.ReadImage(lblpath)
    label = resample(label, (1.0, 1.0, 1.0),
                     interpolator=sitk.sitkNearestNeighbor)
    lbl_arr = sitk.GetArrayFromImage(label)
    # Merge label 2 into label 1 (e.g. lesion into liver) for a binary mask
    lbl_arr[lbl_arr == 2] = 1
    lbl_arr = np.uint8(snd.zoom(lbl_arr, zoom=(0.5, 0.5, 0.5), order=0))
    # Shift the labels up by 1 in the copy before padding, then shift back
    # down afterwards: the zero-valued pad border underflows to 255 (uint8
    # wrap-around), so padded voxels match neither class mask below
    lbl_arr_cp = lbl_arr.copy() + 1
    lbl_arr = np.pad(lbl_arr, ((100, 100), (100, 100), (100, 100)),
                     mode='constant')
    lbl_arr_cp = np.pad(lbl_arr_cp, ((100, 100), (100, 100), (100, 100)),
                        mode='constant')
    lbl_arr_cp -= 1

    # Sample candidate patch centres: class 1 is background (lbl_arr_cp == 0),
    # class 2 is liver (lbl_arr_cp == 1)
    class1_locs = uniform_sample(lbl_arr_cp == 0, 50)
    class2_locs = uniform_sample(lbl_arr_cp == 1, 50)
    # print(' class 1, class 2 :', len(class1_locs), len(class2_locs))
    # Keep 5 background and 45 liver locations per volume
    locs = class1_locs[:5] + class2_locs[:45]
    random.shuffle(locs)

    # Input patches are larger than label patches, presumably to leave margins
    # for valid (unpadded) convolutions in the segmentation network
    patch_size, lbl_size = [116, 132, 132], [28, 44, 44]
    liver_pixel_count = {}
    for idx, l in enumerate(locs):
        # Clamp the centre so the patch stays inside the padded volume
        l = adjust_center_for_boundaries(l, patch_size, img_arr.shape)
        img_patch = extract_patch(img_arr, l, patch_size)
        lbl_patch = extract_patch(lbl_arr, l, lbl_size)
        liver_pixel_count[idx] = np.sum(lbl_patch)
        save_dir = './data/train'
        inppname = 'img' + str(lid) + '_input' + str(idx) + '.npy'
        tgtpname = 'img' + str(lid) + '_label' + str(idx) + '.npy'
        np.save(os.path.join(save_dir, inppname), img_patch)
        np.save(os.path.join(save_dir, tgtpname), lbl_patch)
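
A minimal driver sketch for sample_patches above. The loop bound and the
LiTS-style file naming (volume-N.nii / segmentation-N.nii) are assumptions for
illustration, and ./data/train is assumed to exist; only the call signature
comes from the example.

import os

for lid in range(5):  # hypothetical number of training volumes
    imgpath = os.path.join('./data/raw', 'volume-' + str(lid) + '.nii')
    lblpath = os.path.join('./data/raw', 'segmentation-' + str(lid) + '.nii')
    sample_patches(lid, imgpath, lblpath)
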
Example #3
    def inference(self, inputs):
        if config.mode == "fcn":
            fm = utils.conv_with_bn(inputs, out_channels=12, filter_size=[config.time_width, 13],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_1")

            fm = utils.conv_with_bn(fm, out_channels=16, filter_size=[config.time_width, 11],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_2")

            fm = utils.conv_with_bn(fm, out_channels=20, filter_size=[config.time_width, 9],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_3")

            fm_skip = utils.conv_with_bn(fm, out_channels=24, filter_size=[config.time_width, 7],
                                         stride=1, act='relu', is_training=self._is_training,
                                         padding="SAME", name="conv_4")

            fm = utils.conv_with_bn(fm_skip, out_channels=32, filter_size=[config.time_width, 7],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_5")

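            # Residual skip: conv_6's output (below) is summed with fm_skip,
            # the conv_4 output, forming a skip connection around conv_5.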
            fm = utils.conv_with_bn(fm, out_channels=24, filter_size=[config.time_width, 7],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_6") + fm_skip

            fm = utils.conv_with_bn(fm, out_channels=20, filter_size=[config.time_width, 9],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_7")

            fm = utils.conv_with_bn(fm, out_channels=16, filter_size=[config.time_width, 11],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_8")

            fm = utils.conv_with_bn(fm, out_channels=12, filter_size=[config.time_width, 13],
                                    stride=1, act='relu', is_training=self._is_training,
                                    padding="SAME", name="conv_9")

            fm = utils.conv_with_bn(fm, out_channels=1, filter_size=[config.time_width, config.freq_size],
                                    stride=1, act='linear', is_training=self._is_training,
                                    padding="SAME", name="conv_10")  # (batch_size, 1, config.freq_size, 1)

            # fm = utils.conv_with_bn(fm, out_channels=1, filter_size=[config.time_width, 1],
            #                         stride=1, act='linear', is_training=self._is_training,
            #                         padding="VALID", name="conv_last")
            fm = tf.squeeze(fm, [1, 3])

            return fm

        elif config.mode == "fnn":

            keep_prob = self.keep_prob

            inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))
            inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

            h1 = tf.nn.selu(utils.affine_transform(inputs, 2048, name='hidden_1'))
            h1 = tf.nn.dropout(h1, keep_prob=keep_prob)

            h2 = tf.nn.selu(utils.affine_transform(h1, 2048, name='hidden_2'))
            h2 = tf.nn.dropout(h2, keep_prob=keep_prob)

            h3 = tf.nn.selu(utils.affine_transform(h2, 2048, name='hidden_3'))
            h3 = tf.nn.dropout(h3, keep_prob=keep_prob)

            fm = utils.affine_transform(h3, config.freq_size, name='logits')

            return fm

        elif config.mode == "irm":

            keep_prob = self.keep_prob

            inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))
            inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

            h1 = tf.nn.selu(utils.affine_transform(inputs, 2048, name='hidden_1'))
            h1 = tf.nn.dropout(h1, keep_prob=keep_prob)

            h2 = tf.nn.selu(utils.affine_transform(h1, 2048, name='hidden_2'))
            h2 = tf.nn.dropout(h2, keep_prob=keep_prob)

            h3 = tf.nn.selu(utils.affine_transform(h2, 2048, name='hidden_3'))
            h3 = tf.nn.dropout(h3, keep_prob=keep_prob)

            fm = utils.affine_transform(h3, config.freq_size, name='logits')

            return fm

        elif config.mode == "sfnn":

            keep_prob = self.keep_prob
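            # Keep the centre time frame of the input as a skip connection;
            # it is added back onto the logits at the end of this branch.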
            skip_inputs = tf.squeeze(inputs[:, int(config.time_width/2), :])
            inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width*config.freq_size)))
            inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

            h1 = tf.nn.selu(utils.affine_transform(inputs, 2048, name='hidden_1'))
            h1 = tf.nn.dropout(h1, keep_prob=keep_prob)

            h2 = tf.nn.selu(utils.affine_transform(h1, 2048, name='hidden_2'))
            h2 = tf.nn.dropout(h2, keep_prob=keep_prob)

            h3 = tf.nn.selu(utils.affine_transform(h2, 2048, name='hidden_3'))
            h3 = tf.nn.dropout(h3, keep_prob=keep_prob)

            fm = utils.affine_transform(h3, config.freq_size, name='logits')
            fm = fm + skip_inputs

            return fm
        elif config.mode == "lstm":

            keep_prob = self.keep_prob

            # inputs = tf.squeeze(inputs)[:, int(config.time_width/2), :]

            # inputs = tf.reshape(inputs, (-1, config.time_width, config.freq_size))  # time_width == num_steps
            # inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

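            # Two stacked 1024-unit LSTM layers; OutputProjectionWrapper adds
            # a linear projection from the top layer's output to freq_size.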
            num_units = [1024, 1024]
            cells = [tf.nn.rnn_cell.LSTMCell(num_units=n, state_is_tuple=True) for n in num_units]

            cell = tf.nn.rnn_cell.MultiRNNCell(cells=cells, state_is_tuple=True)
            cell = tf.contrib.rnn.OutputProjectionWrapper(cell, output_size=config.freq_size)
            outputs, _state = tf.nn.dynamic_rnn(cell, inputs, time_major=False, dtype=tf.float32)
            fm = tf.reshape(outputs, [-1, config.freq_size])

            return fm

        elif config.mode == "tsn":
            conv_inputs = tf.squeeze(tf.transpose(inputs, [0, 2, 1, 3]), axis=3)

            keep_prob = self.keep_prob

            # Note: skip_inputs is computed here but not used again in this
            # branch; only this squeezed full-context form is kept.
            skip_inputs = tf.squeeze(inputs, axis=3)

            inputs = tf.reshape(tf.squeeze(inputs, [3]), (-1, int(config.time_width * config.freq_size)))
            inputs = tf.nn.dropout(inputs, keep_prob=keep_prob)

            h1 = tf.nn.selu(utils.affine_transform(inputs, 1024, name='hidden_1'))
            h1 = tf.nn.dropout(h1, keep_prob=keep_prob)

            h2 = tf.nn.selu(utils.affine_transform(h1, 1024, name='hidden_2'))
            h2 = tf.nn.dropout(h2, keep_prob=keep_prob)

            h3 = tf.nn.selu(utils.affine_transform(h2, 1024, name='hidden_3'))
            h3 = tf.nn.dropout(h3, keep_prob=keep_prob)

            fm = utils.affine_transform(h3, int(config.freq_size * config.time_width), name='logits')
            fm = tf.reshape(fm, (-1, config.time_width, config.freq_size))

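            # Unfold the FNN output into overlapping time-context windows per
            # frame; utils.extract_patch is assumed to slide a window over the
            # zero-padded map, and the splits stack the offsets as channels.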
            pad = tf.zeros((1, config.freq_size * int(config.time_width / 2), config.time_width, 1))
            conv_fm = tf.reshape(tf.transpose(tf.expand_dims(fm, axis=3), [0, 2, 1, 3]),
                                 (1, -1, config.time_width, 1))
            conv_fm = tf.concat([pad, conv_fm, pad], axis=1)
            conv_fm = utils.extract_patch(tf.squeeze(conv_fm),
                                          patch_size=(config.freq_size * config.time_width, config.time_width))
            conv_fm = tf.stack(tf.split(conv_fm, num_or_size_splits=config.time_width, axis=1), axis=3)
            conv_fm = tf.reshape(conv_fm, (-1, config.freq_size, config.time_width * config.time_width))

            # att_inputs = tf.reshape(conv_fm, (-1, config.freq_size*config.time_width*config.time_width))
            # h4 = tf.nn.selu(utils.affine_transform(att_inputs, 1024, name='hidden_4'))
            # h4 = tf.nn.dropout(h4, keep_prob=keep_prob)
            # h5 = tf.nn.selu(utils.affine_transform(h4, config.time_width*config.time_width, name='hidden_5'))
            # att_outputs = tf.expand_dims(tf.nn.softmax(h5), axis=1)

            conv_fm = tf.concat([conv_fm, conv_inputs], axis=2)

            conv_fm = tf.expand_dims(conv_fm, axis=2)

            conv_1 = utils.conv_with_bn_2(conv_fm, 256, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_1')

            conv_2 = utils.conv_with_bn_2(conv_1, 128, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_2')

            conv_3 = utils.conv_with_bn_2(conv_2, 64, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_3')

            conv_4 = utils.conv_with_bn_2(conv_3, 32, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_4')

            conv_5 = utils.conv_with_bn_2(conv_4, 32, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_5')

            conv_6 = utils.conv_with_bn_2(conv_5, 32, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_6')
            conv_7 = utils.conv_with_bn_2(conv_6, 32, filter_size=[5, 1], stride=1, act='relu', scale=True,
                                          is_training=self._is_training, padding="SAME", name='conv_7')
            conv_8 = utils.conv_with_bn_2(conv_7, 1, filter_size=[5, 1], stride=1, act='relu', scale=False,
                                          is_training=self._is_training, padding="SAME", name='conv_8')
            conv_9 = tf.squeeze(tf.squeeze(conv_8, axis=2), axis=2)

        return fm, conv_9
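
A rough sketch of how inference might be driven, in the TF 1.x style of the
example. The owner class name, placeholder shape, and the config values below
are assumptions for illustration, not part of the original code.

import tensorflow as tf  # TF 1.x API, matching the example

time_width, freq_size = 5, 257     # assumed; real values come from config
inputs = tf.placeholder(tf.float32, [None, time_width, freq_size, 1])
model = EnhancementModel()         # hypothetical class owning inference()
outputs = model.inference(inputs)  # e.g. (batch, freq_size) in "fnn" mode
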
Example #4
def main():
    parser = argparse.ArgumentParser(
        description="Extract the 3D patches containing the nodules and store them in TFRecord files.",
    )
    parser.add_argument(
        "data_dir",
        help="Directory containing all the DCM files downloaded from https://wiki.cancerimagingarchive.net/display/Public/SPIE-AAPM+Lung+CT+Challenge",
    )
    parser.add_argument(
        "test_xls_file",
        help="Test Excel file obtained from https://wiki.cancerimagingarchive.net/display/Public/SPIE-AAPM+Lung+CT+Challenge",
    )
    # parser.add_argument(
    #    "calibration_xls_file",
    #    help="Calibration Excel file obtained from https://wiki.cancerimagingarchive.net/display/Public/SPIE-AAPM+Lung+CT+Challenge",
    # )

    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    df = read_xls(args.test_xls_file)
    # test_df = read_xls(args.test_xls_file)
    # calibration_df = read_xls(args.calibration_xls_file)
    # df = pd.concat([test_df, calibration_df])
    assert len(df.index) == 73, "The input Excel file does not have the expected size."

    with tf.io.TFRecordWriter(
        SPIE_SMALL_NEG_TFRECORD
    ) as small_neg_writer, tf.io.TFRecordWriter(
        SPIE_SMALL_POS_TFRECORD
    ) as small_pos_writer, tf.io.TFRecordWriter(
        SPIE_BIG_NEG_TFRECORD
    ) as big_neg_writer, tf.io.TFRecordWriter(
        SPIE_BIG_POS_TFRECORD
    ) as big_pos_writer:
        for row in tqdm(df.itertuples(), total=len(df.index)):
            dcm_dir_glob = Path(data_dir).glob(f"{row.scan_number}/*/*/")
            dcm_dir = list(dcm_dir_glob)[0]
            scan = read_dcm(dcm_dir, reverse_z=False)

            big_writer = big_pos_writer if row.label else big_neg_writer
            big_patch = extract_patch(
                scan,
                (row.zloc, row.yloc, row.xloc),
                BIG_PATCH_SHAPE[:-1],
            )
            big_patch = pad_to_shape(big_patch, BIG_PATCH_SHAPE[:-1])
            big_patch = np.expand_dims(big_patch, axis=-1)
            if not big_patch.any():
                print(
                    f"WARNING ({row.scan_number=}): "
                    "Patch contains only zeros. Skipping this patch ..."
                )
                continue
            assert (
                big_patch.shape == BIG_PATCH_SHAPE
            ), f"Wrong shape for scan {row.scan_number}."
            big_patch = big_patch.astype(np.float32)
            big_example = volume_to_example(big_patch)
            big_writer.write(big_example.SerializeToString())

            small_writer = small_pos_writer if row.label else small_neg_writer
            small_patch = extract_patch(
                scan,
                (row.zloc, row.yloc, row.xloc),
                SMALL_PATCH_SHAPE[:-1],
            )
            small_patch = pad_to_shape(small_patch, SMALL_PATCH_SHAPE[:-1])
            small_patch = np.expand_dims(small_patch, axis=-1)
            assert (
                small_patch.shape == SMALL_PATCH_SHAPE
            ), f"Wrong shape for scan {row.scan_number}."
            small_patch = small_patch.astype(np.float32)
            small_example = volume_to_example(small_patch)
            small_writer.write(small_example.SerializeToString())
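
A hedged sketch of reading one of these TFRecord files back with tf.data. The
example does not show what volume_to_example writes, so the "volume" feature
key, the tensor serialization, and the patch shape below are assumptions.

import tensorflow as tf

def parse_volume(serialized, shape=(48, 48, 48, 1)):  # shape is assumed
    features = tf.io.parse_single_example(
        serialized,
        {"volume": tf.io.FixedLenFeature([], tf.string)},  # assumed key
    )
    volume = tf.io.parse_tensor(features["volume"], out_type=tf.float32)
    return tf.reshape(volume, shape)

dataset = tf.data.TFRecordDataset(SPIE_BIG_POS_TFRECORD).map(parse_volume)
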
Example #5
def main():
    parser = argparse.ArgumentParser(
        description="Extract the 3D patches containing the nodules and store them in TFRecord files.",
    )
    parser.add_argument(
        "data_dir",
        help="Directory containing all the DCM files downloaded from https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI",
    )
    parser.add_argument(
        "csv_file",
        help="CSV file obtained from http://www.via.cornell.edu/lidc",
    )

    args = parser.parse_args()

    data_dir = Path(args.data_dir)
    nodules_df = read_lidc_size_report(args.csv_file)
    assert (
        len(nodules_df.index) == 1387
    ), f"The input CSV {args.csv_file} does not have the expected size."

    with tf.io.TFRecordWriter(
        LIDC_SMALL_NEG_TFRECORD
    ) as small_neg_writer, tf.io.TFRecordWriter(
        LIDC_SMALL_POS_TFRECORD
    ) as small_pos_writer, tf.io.TFRecordWriter(
        LIDC_BIG_NEG_TFRECORD
    ) as big_neg_writer, tf.io.TFRecordWriter(
        LIDC_BIG_POS_TFRECORD
    ) as big_pos_writer, tf.io.TFRecordWriter(
        LIDC_SMALL_UNLABELED_TFRECORD
    ) as small_unlabeled_writer, tf.io.TFRecordWriter(
        LIDC_BIG_UNLABELED_TFRECORD
    ) as big_unlabeled_writer:
        for row in tqdm(nodules_df.itertuples(), total=len(nodules_df.index)):
            case = row.case
            scan_id = row.scan
            dcm_dir_glob = list(
                data_dir.glob(f"LIDC-IDRI-{case}/*/{scan_id}.*/"))
            if len(dcm_dir_glob) == 0:
                print(f"WARNING ({scan_id=} {case=}): "
                      "Scan not found. Skipping this scan ...")
                continue
            if len(dcm_dir_glob) > 1:
                print(
                    f"WARNING ({scan_id=} {case=}): "
                    "Found multiple scans with same ids. Skipping this scan ..."
                )
                continue
            dcm_dir = dcm_dir_glob[0]
            scan = read_dcm(dcm_dir, reverse_z=True)
            xml_files = list(dcm_dir.glob("*.xml"))
            if len(xml_files) == 0:
                print(f"WARNING ({scan_id=} {case=}): "
                      "Can't find a XML file. Skipping this scan ...")
                continue
            elif len(xml_files) > 1:
                print(f"WARNING ({scan_id=} {case=}): "
                      "Found multiple XML files. Skipping this scan ...")
                continue
            xml_file = xml_files[0]
            nodule_ids = row.ids
            malignancies = get_malignancies(xml_file, nodule_ids)
            median_malignancy = median(malignancies)
            if median_malignancy < 3:
                big_writer = big_neg_writer
                small_writer = small_neg_writer
            elif median_malignancy > 3:
                big_writer = big_pos_writer
                small_writer = small_pos_writer
            else:
                # if the malignancies median is 3 then write the patch
                # as unlabeled
                big_writer = big_unlabeled_writer
                small_writer = small_unlabeled_writer

            big_patch = extract_patch(
                scan,
                (row.zloc, row.yloc, row.xloc),
                BIG_PATCH_SHAPE[:-1],
            )
            big_patch = pad_to_shape(big_patch, BIG_PATCH_SHAPE[:-1])
            big_patch = np.expand_dims(big_patch, axis=-1)
            if not big_patch.any():
                print(f"WARNING ({scan_id=} {case=}): "
                      "Patch contains only zeros. Skipping this patch ...")
                continue
            assert (big_patch.shape == BIG_PATCH_SHAPE
                    ), f"Wrong shape for scan {scan_id} in case {case}."
            big_patch = big_patch.astype(np.float32)
            big_example = volume_to_example(big_patch)
            big_writer.write(big_example.SerializeToString())

            small_patch = extract_patch(
                scan,
                (row.zloc, row.yloc, row.xloc),
                SMALL_PATCH_SHAPE[:-1],
            )
            small_patch = pad_to_shape(small_patch, SMALL_PATCH_SHAPE[:-1])
            small_patch = np.expand_dims(small_patch, axis=-1)
            assert (small_patch.shape == SMALL_PATCH_SHAPE
                    ), f"Wrong shape for scan {scan_id} in case {case}."
            small_patch = small_patch.astype(np.float32)
            small_example = volume_to_example(small_patch)
            small_writer.write(small_example.SerializeToString())
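
The writer routing above keys off the median of the radiologists' malignancy
ratings (1 to 5 in LIDC): a median below 3 goes to the negative files, above 3
to the positive files, and exactly 3 is written as unlabeled. A tiny
self-contained illustration of the rule:

from statistics import median

for ratings in ([1, 2, 2], [4, 5], [2, 3, 4]):
    m = median(ratings)
    label = "negative" if m < 3 else "positive" if m > 3 else "unlabeled"
    print(ratings, "->", m, label)
# [1, 2, 2] -> 2 negative
# [4, 5] -> 4.5 positive
# [2, 3, 4] -> 3 unlabeled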