Example #1
def test_labelindex(labelmap_setup):
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup

    # Need an unlocked node to test these posts
    uuid = post_branch(dvid_server, dvid_repo, 'test_labelindex',
                       'test_labelindex')
    instance_info = (dvid_server, uuid, 'segmentation-scratch')

    # Write some random data
    sv = 99
    vol = sv * np.random.randint(2, size=(128, 128, 128), dtype=np.uint64)
    offset = np.array((64, 64, 64))
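    # The offset and volume shape are aligned to DVID's 64-voxel blocks,
    # so this volume spans exactly 2x2x2 = 8 blocks.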

    # DVID will generate the index.
    post_labelmap_voxels(*instance_info, offset, vol)

    # Compute labelindex table from scratch
    rows = []
    for block_coord in ndrange(offset, offset + vol.shape, (64, 64, 64)):
        block_coord = np.array(block_coord)
        block_box = np.array((block_coord, block_coord + 64))
        block = extract_subvol(vol, block_box - offset)

        count = (block == sv).sum()
        rows.append([*block_coord, sv, count])

    index_df = pd.DataFrame(rows, columns=['z', 'y', 'x', 'sv', 'count'])

    # Check DVID's generated labelindex table against expected
    labelindex_tuple = fetch_labelindex(*instance_info, sv, format='pandas')
    assert labelindex_tuple.label == sv

    labelindex_tuple.blocks.sort_values(['z', 'y', 'x', 'sv'], inplace=True)
    labelindex_tuple.blocks.reset_index(drop=True, inplace=True)
    assert (labelindex_tuple.blocks == index_df).all().all()

    # Check our protobuf against DVID's
    index_tuple = PandasLabelIndex(index_df, sv, 1,
                                   datetime.datetime.now().isoformat(),
                                   'someuser')
    labelindex = create_labelindex(index_tuple)

    # Since labelindex block entries are not required to be sorted,
    # dvid might return them in a different order.
    # Hence this comparison function which sorts them first.
    def compare_proto_blocks(left, right):
        left_blocks = sorted(left.blocks.items())
        right_blocks = sorted(right.blocks.items())
        return left_blocks == right_blocks

    dvid_labelindex = fetch_labelindex(*instance_info, sv, format='protobuf')
    assert compare_proto_blocks(labelindex, dvid_labelindex)

    # Check post/get roundtrip
    post_labelindex(*instance_info, sv, labelindex)
    dvid_labelindex = fetch_labelindex(*instance_info, sv, format='protobuf')
    assert compare_proto_blocks(labelindex, dvid_labelindex)
Example #2
def test_fetch_labelindices(labelmap_setup):
    dvid_server, dvid_repo, _merge_table_path, _mapping_path, _supervoxel_vol = labelmap_setup

    # Need an unlocked node to test these posts
    uuid = post_branch(dvid_server, dvid_repo, 'test_labelindices',
                       'test_labelindices')
    instance_info = (dvid_server, uuid, 'segmentation-scratch')

    # Write some random data
    vol = np.random.randint(1, 10, size=(128, 128, 128), dtype=np.uint64)
    offset = np.array((64, 64, 64))

    # DVID will generate the index.
    post_labelmap_voxels(*instance_info, offset, vol)

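    # The default format returns a single protobuf message whose .indices field
    # holds one LabelIndex per requested label, in the same order as the request.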
    labelindices = fetch_labelindices(*instance_info, list(range(1, 10)))
    for sv, li in zip(range(1, 10), labelindices.indices):
        # This function is already tested elsewhere, so we'll use it as a reference
        li2 = fetch_labelindex(*instance_info, sv)
        assert li == li2

    labelindices = fetch_labelindices(*instance_info,
                                      list(range(1, 10)),
                                      format='list-of-protobuf')
    for sv, li in zip(range(1, 10), labelindices):
        # This function is already tested elsewhere, so we'll use it as a reference
        li2 = fetch_labelindex(*instance_info, sv)
        assert li == li2

    labelindices = fetch_labelindices(*instance_info,
                                      list(range(1, 10)),
                                      format='pandas')
    for sv, li in zip(range(1, 10), labelindices):
        # This function is already tested elsewhere, so we'll use it as a reference
        li2 = fetch_labelindex(*instance_info, sv, format='pandas')
        li_df = li.blocks.sort_values(['z', 'y', 'x']).reset_index(drop=True)
        li2_df = li2.blocks.sort_values(['z', 'y', 'x']).reset_index(drop=True)
        assert (li_df == li2_df).all().all()

    # Test the copy function (just do a round-trip -- hopefully I didn't swap src and dest anywhere...)
    copy_labelindices(instance_info,
                      instance_info,
                      list(range(1, 10)),
                      batch_size=2)
    copy_labelindices(instance_info,
                      instance_info,
                      list(range(1, 10)),
                      batch_size=2,
                      processes=2)
Example #3
        def fetch_brick_coords(body, supervoxel_subset):
            """
            Fetch the block coordinates for the given body,
            filter them for the given supervoxels (if any),
            and convert the block coordinates to brick coordinates.
            """
            assert is_supervoxels or supervoxel_subset is None

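            # The resource manager context presumably throttles concurrent
            # read access to DVID before each labelindex fetch.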
            try:
                with mgr.access_context(server, True, 1, 1):
                    labelindex = fetch_labelindex(server, uuid, instance, body,
                                                  'protobuf')
                coords_df = convert_labelindex_to_pandas(labelindex).blocks

            except HTTPError as ex:
                if (ex.response is not None
                        and ex.response.status_code == 404):
                    return (body, None)
                raise
            except RuntimeError as ex:
                if 'does not map to any body' in str(ex):
                    return (body, None)
                raise

            if len(coords_df) == 0:
                return (body, None)

            if is_supervoxels:
                supervoxel_subset = set(supervoxel_subset)
                coords_df = coords_df.query('sv in @supervoxel_subset').copy()

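            # Integer-divide block coordinates by the brick shape to get
            # brick-grid coordinates, then drop duplicate bricks for this body.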
            coords_df[['z', 'y', 'x']] //= brick_shape
            coords_df['body'] = np.uint64(body)
            coords_df.drop_duplicates(inplace=True)
            return (body, coords_df)
Example #4
def _repair_index(master_seg, body):
    pli = fetch_labelindex(*master_seg, body, format='pandas')

    # Just drop the blocks below coordinate 1024
    # (That's where the bad blocks were added, and
    # there isn't supposed to be segmentation in that region.)
    pli.blocks.query('z >= 1024 and y >= 1024 and x >= 1024', inplace=True)

    li = create_labelindex(pli)
    post_labelindex(*master_seg, pli.label, li)
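
# Minimal usage sketch for the helper above; the server/uuid/instance and body
# IDs here are hypothetical placeholders, not taken from the original script.
master_seg = ('http://dvid.example.org:8000', 'abc123', 'segmentation')
for damaged_body in (1001, 1002):
    _repair_index(master_seg, damaged_body)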
Example #5
    def process_batch(self, batch_and_rowcount):
        """
        Takes a batch of grouped stats rows and sends it to dvid in the appropriate protobuf format.

        If self.check_mismatches is True, read the existing labelindex for each
        body, post only the indexes that are missing or whose blocks differ,
        and return the mismatched and missing labels.
        """
        next_stats_batch, next_stats_batch_total_rows = batch_and_rowcount
        labelindex_batch = chain(
            *map(self.label_indexes_for_body, next_stats_batch))

        if not self.check_mismatches:
            post_labelindex_batch(*self.instance_info, labelindex_batch)
            return next_stats_batch_total_rows, [], []

        # Check for mismatches
        mismatch_batch = []
        missing_batch = []
        for labelindex in labelindex_batch:
            try:
                existing_labelindex = fetch_labelindex(*self.instance_info,
                                                       labelindex.label)
            except requests.RequestException as ex:
                missing_batch.append(labelindex)
                if ex.response is None:
                    logger.warning(
                        f"Failed to fetch LabelIndex for label: {labelindex.label} due to no response"
                    )
                elif not str(ex.response.status_code).startswith('4'):
                    logger.warning(
                        f"Failed to fetch LabelIndex for label: {labelindex.label} due to error {ex.response.status_code}"
                    )
            else:
                if (labelindex.blocks != existing_labelindex.blocks):
                    # Update the mut_id to match the previous one.
                    labelindex.last_mutid = existing_labelindex.last_mutid
                    mismatch_batch.append(labelindex)

        # Post mismatches (only)
        post_labelindex_batch(*self.instance_info,
                              mismatch_batch + missing_batch)

        # Return mismatch IDs
        mismatch_labels = [labelindex.label for labelindex in mismatch_batch]
        missing_labels = [labelindex.label for labelindex in missing_batch]

        return next_stats_batch_total_rows, mismatch_labels, missing_labels
Example #6
    def process_batch(self, batch_and_rowcount):
        """
        Takes a batch of grouped stats rows and sends it to dvid in the appropriate protobuf format.

        If self.check_mismatches is True, read the existing labelindex for each
        body, post only the indexes that are missing or whose blocks differ,
        and return the mismatched and missing labels.
        """
        next_stats_batch, next_stats_batch_total_rows = batch_and_rowcount
        labelindex_batch = chain(*map(self.label_indexes_for_body, next_stats_batch))

        if not self.check_mismatches:
            self.post_labelindex_batch(labelindex_batch)
            return next_stats_batch_total_rows, [], []

        # Check for mismatches
        mismatch_batch = []
        missing_batch = []
        for labelindex in labelindex_batch:
            try:
                existing_labelindex = fetch_labelindex(*self.instance_info, labelindex.label, session=self.session)
            except requests.RequestException as ex:
                missing_batch.append(labelindex)
                if ex.response is None:
                    logger.warning(f"Failed to fetch LabelIndex for label: {labelindex.label} due to no response")
                elif not str(ex.response.status_code).startswith('4'):
                    logger.warning(f"Failed to fetch LabelIndex for label: {labelindex.label} due to error {ex.response.status_code}")
            else:
                if (labelindex.blocks != existing_labelindex.blocks):
                    # Update the mut_id to match the previous one.
                    labelindex.last_mutid = existing_labelindex.last_mutid
                    mismatch_batch.append(labelindex)

        # Post mismatches (only)
        self.post_labelindex_batch(mismatch_batch + missing_batch)

        # Return mismatch IDs
        mismatch_labels = [labelindex.label for labelindex in mismatch_batch]
        missing_labels = [labelindex.label for labelindex in missing_batch]
        
        return next_stats_batch_total_rows, mismatch_labels, missing_labels
Example #7
def test_masksegmentation_basic(setup_dvid_segmentation_input, invert_mask,
                                roi_dilation, disable_auto_retry):
    template_dir, config, volume, dvid_address, repo_uuid, roi_mask_s5, input_segmentation_name, output_segmentation_name = setup_dvid_segmentation_input

    if invert_mask:
        roi_mask_s5 = ~roi_mask_s5

    config["masksegmentation"]["invert-mask"] = invert_mask
    config["masksegmentation"]["dilate-roi"] = roi_dilation

    # re-dump config
    yaml = YAML()
    yaml.default_flow_style = False
    with open(f"{template_dir}/workflow.yaml", 'w') as f:
        yaml.dump(config, f)

    execution_dir, workflow = launch_flow(template_dir, 1)
    final_config = workflow.config

    input_box_xyz = np.array(final_config['input']['geometry']['bounding-box'])
    input_box_zyx = input_box_xyz[:, ::-1]

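    # The ROI mask is given at scale 5, so upsample it by 2**5 to full
    # resolution and crop it to the input bounding box.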
    roi_mask = upsample(roi_mask_s5, 2**5)
    roi_mask = extract_subvol(roi_mask, input_box_zyx)

    expected_vol = extract_subvol(volume.copy(), input_box_zyx)
    expected_vol[roi_mask] = 0

    output_box_xyz = np.array(
        final_config['output']['geometry']['bounding-box'])
    output_box_zyx = output_box_xyz[:, ::-1]
    output_vol = fetch_labelmap_voxels(dvid_address,
                                       repo_uuid,
                                       output_segmentation_name,
                                       output_box_zyx,
                                       scale=0,
                                       supervoxels=True)

    # Create a copy of the volume that contains only the voxels we removed
    erased_vol = volume.copy()
    erased_vol[~roi_mask] = 0

    if EXPORT_DEBUG_FILES:
        original_vol = fetch_labelmap_voxels(dvid_address,
                                             repo_uuid,
                                             input_segmentation_name,
                                             output_box_zyx,
                                             scale=0,
                                             supervoxels=True)
        original_agglo_vol = fetch_labelmap_voxels(dvid_address,
                                                   repo_uuid,
                                                   input_segmentation_name,
                                                   output_box_zyx,
                                                   scale=0)
        output_agglo_vol = fetch_labelmap_voxels(dvid_address,
                                                 repo_uuid,
                                                 output_segmentation_name,
                                                 output_box_zyx,
                                                 scale=0)
        np.save('/tmp/original-svs.npy', original_vol)
        np.save('/tmp/original-agglo.npy', original_agglo_vol)
        np.save('/tmp/output.npy', output_vol)
        np.save('/tmp/output-agglo.npy', output_agglo_vol)
        np.save('/tmp/expected.npy', expected_vol)
        np.save('/tmp/erased.npy', erased_vol)

        shutil.copyfile(f'{execution_dir}/roi-mask.h5', '/tmp/roi-mask.h5')
        if roi_dilation:
            shutil.copyfile(f'{execution_dir}/dilated-roi-mask.h5',
                            '/tmp/dilated-roi-mask.h5')
        if invert_mask:
            shutil.copyfile(f'{execution_dir}/segmentation-mask.h5',
                            '/tmp/segmentation-mask.h5')
        shutil.copyfile(f'{execution_dir}/final-mask.h5', '/tmp/final-mask.h5')

    if roi_dilation > 0:
        # FIXME: We don't yet verify voxel-accuracy of ROI dilation.
        return

    assert (output_vol == expected_vol).all(), \
        "Written vol does not match expected"

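    # Check the downsampled pyramid: downsample the expected volume one scale
    # at a time and compare it against the corresponding scale stored in DVID.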
    scaled_expected_vol = expected_vol
    for scale in range(1, 1 + MAX_SCALE):
        scaled_expected_vol = downsample(scaled_expected_vol, 2,
                                         'labels-numba')
        scaled_output_vol = fetch_labelmap_voxels(dvid_address,
                                                  repo_uuid,
                                                  output_segmentation_name,
                                                  output_box_zyx // 2**scale,
                                                  scale=scale,
                                                  supervoxels=True)

        if EXPORT_DEBUG_FILES:
            np.save(f'/tmp/expected-{scale}.npy', scaled_expected_vol)
            np.save(f'/tmp/output-{scale}.npy', scaled_output_vol)

        if scale <= 5:
            assert (scaled_output_vol == scaled_expected_vol).all(), \
                f"Written vol does not match expected at scale {scale}"
        else:
            # For scale 6 and 7, some blocks are not even changed,
            # but that means we would be comparing DVID's label
            # downsampling method to our method ('labels-numba').
            # The two don't necessarily give identical results in the case of 'ties',
            # so we'll just verify that the nonzero voxels match, at least.
            assert ((scaled_output_vol == 0) == (scaled_expected_vol == 0)).all(), \
                f"Written vol does not match expected at scale {scale}"

    block_stats_path = f'{execution_dir}/erased-block-statistics.h5'
    with h5py.File(block_stats_path, 'r') as f:
        stats_df = pd.DataFrame(f['stats'][:])

    #
    # Check the exported block statistics
    #
    stats_cols = [*BLOCK_STATS_DTYPES.keys()]
    assert stats_df.columns.tolist() == stats_cols
    stats_df = stats_df.sort_values(stats_cols).reset_index()

    expected_stats_df = block_stats_for_volume((64, 64, 64), erased_vol,
                                               input_box_zyx)
    expected_stats_df = expected_stats_df.sort_values(stats_cols).reset_index()

    assert len(stats_df) == len(expected_stats_df)
    assert (stats_df == expected_stats_df).all().all()

    #
    # Try updating the labelindexes
    #
    src_info = (dvid_address, repo_uuid, input_segmentation_name)
    dest_info = (dvid_address, repo_uuid, output_segmentation_name)
    with switch_cwd(execution_dir):
        erase_from_labelindexes(src_info,
                                dest_info,
                                block_stats_path,
                                batch_size=10,
                                threads=4)

    # Verify deleted supervoxels
    assert os.path.exists(f'{execution_dir}/deleted-supervoxels.csv')
    deleted_svs = set(
        pd.read_csv(f'{execution_dir}/deleted-supervoxels.csv')['sv'])

    orig_svs = {*pd.unique(volume.reshape(-1))} - {0}
    remaining_svs = {*pd.unique(expected_vol.reshape(-1))} - {0}
    expected_deleted_svs = orig_svs - remaining_svs
    assert deleted_svs == expected_deleted_svs

    # Verify remaining sizes
    expected_sv_counts = (pd.Series(
        expected_vol.reshape(-1),
        name='sv').value_counts().drop(0).sort_index().rename('count'))

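    # Sum the per-block supervoxel counts from each remaining body's labelindex
    # and compare them against the voxel counts in the expected volume.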
    index_dfs = []
    for body in np.unique(fetch_mapping(*dest_info, remaining_svs)):
        index_df = fetch_labelindex(*dest_info, body, format='pandas').blocks
        index_dfs.append(index_df)

    sv_counts = (pd.concat(index_dfs, ignore_index=True)[[
        'sv', 'count'
    ]].groupby('sv')['count'].sum().sort_index())
    assert set(sv_counts.index.values) == set(expected_sv_counts.index.values)
    assert (sv_counts == expected_sv_counts).all(), \
        pd.DataFrame({'stored_count': sv_counts, 'expected_count': expected_sv_counts}).query('stored_count != expected_count')

    # Verify mapping
    # Deleted supervoxels exist in the mapping, but they map to 0.
    assert (fetch_mapping(*dest_info, [*deleted_svs]) == 0).all()

    # Remaining supervoxels still map to their original bodies
    assert (fetch_mapping(*dest_info, [*remaining_svs]) == fetch_mapping(
        *src_info, [*remaining_svs])).all()
Example #8
def main():
    configure_default_logging()
    
    parser = argparse.ArgumentParser()
    parser.add_argument('server')
    parser.add_argument('uuid')
    parser.add_argument('instance')
    parser.add_argument('block_stats')
    args = parser.parse_args()
    
    seg_instance = (args.server, args.uuid, args.instance)
    
    from flyemflows.bin.ingest_label_indexes import load_stats_h5_to_records
    with Timer("Loading block stats", logger):
        (block_sv_stats, _presorted_by, _agglo_path) = load_stats_h5_to_records(args.block_stats)
        stats_df = pd.DataFrame(block_sv_stats)
        stats_df = stats_df[['z', 'y', 'x', 'segment_id', 'count']]
        stats_df = stats_df.rename(columns={'segment_id': 'sv'})
        
        # Keep only the new supervoxels.
        stats_df = stats_df.query('sv > @NEW_SV_THRESHOLD').copy()
    
    with Timer("Fetching old labelindex", logger):
        labelindex = fetch_labelindex(*seg_instance, 106979579, format='protobuf')

    with Timer("Extracting labelindex table", logger):
        old_df = convert_labelindex_to_pandas(labelindex).blocks

    with Timer("Patching labelindex table", logger):
        # Discard old supervoxel stats within patched area
        in_patch  = (old_df[['z', 'y', 'x']].values >= PATCH_BOX_ZYX[0]).all(axis=1)
        in_patch &= (old_df[['z', 'y', 'x']].values  < PATCH_BOX_ZYX[1]).all(axis=1)
        
        old_df['in_patch'] = in_patch
        unpatched_df = old_df.query('not (in_patch and sv == @FRANKENBODY_SV)').copy()
        del unpatched_df['in_patch']
        
        # Append new stats
        new_df = pd.concat((unpatched_df, stats_df), ignore_index=True)
        new_df = new_df.sort_values(['z', 'y', 'x', 'sv'])

        np.save('old_df.npy', old_df.to_records(index=False))
        np.save('new_df.npy', new_df.to_records(index=False))

        if old_df['count'].sum() != new_df['count'].sum():
            logger.warning("Old and new indexes do not have the same total counts.  See old_df.npy and new_df.npy")

    with Timer("Constructing new labelindex", logger):    
        last_mutid = fetch_repo_info(*seg_instance[:2])["MutationID"]
        mod_time = datetime.datetime.now().isoformat()
        new_li = PandasLabelIndex(new_df, FRANKENBODY_SV, last_mutid, mod_time, os.environ.get("USER", "unknown"))
        new_labelindex = create_labelindex(new_li)

    with Timer("Posting new labelindex", logger):
        post_labelindex(*seg_instance, FRANKENBODY_SV, new_labelindex)

    with Timer("Posting updated mapping", logger):
        new_mapping = pd.Series(FRANKENBODY_SV, index=new_df['sv'].unique(), dtype=np.uint64, name='body')
        post_mappings(*seg_instance, new_mapping, last_mutid)

    logger.info("DONE")
Example #9
    def process_batch(self, batch_and_rowcount):
        """
        Given a batch of ERASED block stats, fetches the existing LabelIndex,
        subtracts the erased stats, and posts either an updated labelindex or
        a tombstone (if the body is completely erased).
        """
        next_stats_batch, next_stats_batch_total_rows = batch_and_rowcount

        batch_indexes = []
        missing_bodies = []
        unexpected_dfs = []
        all_deleted_svs = []
        for body_group in next_stats_batch:
            body_id = body_group[0]['body_id']

            try:
                old_index = fetch_labelindex(*self.src_info,
                                             body_id,
                                             format='pandas')
            except requests.RequestException as ex:
                missing_bodies.append(body_id)
                if ex.response is None:
                    logger.warning(
                        f"Failed to fetch LabelIndex for label: {body_id} due to no response"
                    )
                elif not str(ex.response.status_code).startswith('4'):
                    logger.warning(
                        f"Failed to fetch LabelIndex for label: {body_id} due to error {ex.response.status_code}"
                    )
                continue

            old_df = old_index.blocks
            erased_df = pd.DataFrame(body_group).rename(
                columns={'segment_id': 'sv'})[['z', 'y', 'x', 'sv', 'count']]
            assert erased_df.columns.tolist() == old_df.columns.tolist()
            assert old_df.duplicated(['z', 'y', 'x', 'sv']).sum() == 0
            assert erased_df.duplicated(['z', 'y', 'x', 'sv']).sum() == 0

            # Find the rows that exist on the old side (or both)
            merged_df = old_df.merge(erased_df,
                                     'outer',
                                     on=['z', 'y', 'x', 'sv'],
                                     suffixes=['_old', '_erased'],
                                     indicator='side')
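            # Rows present on only one side of the outer merge get NaN counts;
            # treat those as zero before subtracting.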
            merged_df['count_old'] = merged_df['count_old'].fillna(0).astype(np.uint32)
            merged_df['count_erased'] = merged_df['count_erased'].fillna(0).astype(np.uint32)

            # If some supervoxel was "erased" from a particular block and the original
            # labelindex didn't mention it, that's a sign of corruption.
            # Save it for subsequent analysis
            unexpected_df = merged_df.query('count_old == 0').copy()
            if len(unexpected_df) > 0:
                unexpected_df['body'] = body_id
                unexpected_dfs.append(unexpected_df)

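            # Subtract the erased counts from the original counts; only rows
            # with voxels remaining stay in the new index.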
            merged_df = merged_df.query('count_old > 0').copy()
            merged_df['count'] = merged_df['count_old'] - merged_df['count_erased']

            new_df = merged_df[['z', 'y', 'x', 'sv', 'count']]
            new_df = new_df.query('count > 0').copy()

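            # Any supervoxel that no longer appears in any block has been
            # deleted from this body entirely.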
            deleted_svs = set(old_df['sv']) - set(new_df['sv'])
            if deleted_svs:
                deleted_svs = np.fromiter(deleted_svs, dtype=np.uint64)
                all_deleted_svs.append(deleted_svs)

            if len(new_df) == 0:
                # Nothing to keep. Make a tombstone.
                tombstone_index = LabelIndex()
                tombstone_index.label = body_id
                tombstone_index.last_mutid = self.last_mutid
                tombstone_index.last_mod_user = self.user
                tombstone_index.last_mod_time = self.mod_time
                batch_indexes.append(tombstone_index)
            else:
                pli = PandasLabelIndex(new_df, body_id, self.last_mutid,
                                       self.mod_time, self.user)
                new_labelindex = create_labelindex(pli)
                batch_indexes.append(new_labelindex)

        # Write entire batch to DVID
        post_labelindex_batch(*self.dest_info, batch_indexes)

        # Return missing body IDs and the set of unexpected rows
        if unexpected_dfs:
            unexpected_df = pd.concat(unexpected_dfs)
        else:
            unexpected_df = None

        if all_deleted_svs:
            all_deleted_svs = np.concatenate(all_deleted_svs)

        return next_stats_batch_total_rows, missing_bodies, unexpected_df, all_deleted_svs