Example #1
def get_snapshot_source(setup_config: Dict[str, Any],
                        source_samples: List[str]):
    snapshot = setup_config.get("SNAPSHOT_SOURCE", "snapshots/snapshot_1.hdf")

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    data_sources = SnapshotSource(
        snapshot=snapshot,
        outputs={
            "volumes/raw": raw,
            "points/consensus": consensus,
            "points/skeletonization": skeletonization,
            "points/matched": matched,
            "points/matched": nonempty_placeholder,
            "points/labels": labels,
        },
        voxel_size=voxel_size,
    )

    return (
        data_sources,
        raw,
        labels,
        consensus,
        nonempty_placeholder,
        skeletonization,
        matched,
    )
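The helper above only builds the snapshot source and its keys; the comment notes that the keys are meant to be requested with size input_size. A minimal usage sketch (not from the original project; the "INPUT_SHAPE" config key, the Normalize node, and the sizes are assumptions):

# hypothetical wiring of the returned source and keys into a small pipeline
source, raw, labels, consensus, nonempty_placeholder, skeletonization, matched = \
    get_snapshot_source(setup_config, source_samples)

voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
input_size = gp.Coordinate(setup_config["INPUT_SHAPE"]) * voxel_size  # assumed config key

request = gp.BatchRequest()
request.add(raw, input_size)
request.add(labels, input_size)
request.add(matched, input_size)

pipeline = source + gp.Normalize(raw)
with gp.build(pipeline):
    batch = pipeline.request_batch(request)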
Example #2
    def test_csv_header(self):
        points = gp.PointsKey("POINTS")
        tswh = TracksSource(TEST_FILE_WITH_HEADER, points)

        request = gp.BatchRequest()
        request.add(points, gp.Coordinate((5, 5, 5, 5)))

        tswh.setup()
        b = tswh.provide(request)
        points = b[points].data
        self.assertListEqual([0.0, 0.0, 0.0, 0.0], list(points[1].location))
        self.assertListEqual([1.0, 0.0, 0.0, 0.0], list(points[2].location))
        self.assertListEqual([1.0, 1.0, 2.0, 3.0], list(points[3].location))
        self.assertListEqual([2.0, 2.0, 2.0, 2.0], list(points[4].location))
Example #3
    def test_delete_points_in_context(self):
        points = gp.PointsKey("POINTS")
        pv_array = gp.ArrayKey("PARENT_VECTORS")
        mask = gp.ArrayKey("MASK")
        radius = [0.1, 0.1, 0.1, 0.1]
        ts = TracksSource(TEST_FILE, points)
        apv = AddParentVectors(points, pv_array, mask, radius)
        request = gp.BatchRequest()
        request.add(points, gp.Coordinate((1, 4, 4, 4)))
        request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
        request.add(mask, gp.Coordinate((1, 4, 4, 4)))

        pipeline = (ts + gp.Pad(points, None) + apv)
        with gp.build(pipeline):
            pipeline.request_batch(request)
Example #4
    def test_pipeline3(self):
        array_key = gp.ArrayKey("TEST_ARRAY")
        points_key = gp.PointsKey("TEST_POINTS")
        voxel_size = gp.Coordinate((1, 1))
        spec = gp.ArraySpec(voxel_size=voxel_size, interpolatable=True)

        hdf5_source = gp.Hdf5Source(self.fake_data_file,
                                    {array_key: 'testdata'},
                                    array_specs={array_key: spec})
        csv_source = gp.CsvPointsSource(
            self.fake_points_file, points_key,
            gp.PointsSpec(
                roi=gp.Roi(shape=gp.Coordinate((100, 100)), offset=(0, 0))))

        request = gp.BatchRequest()
        shape = gp.Coordinate((60, 60))
        request.add(array_key, shape, voxel_size=gp.Coordinate((1, 1)))
        request.add(points_key, shape)

        shift_node = gp.ShiftAugment(prob_slip=0.2,
                                     prob_shift=0.2,
                                     sigma=5,
                                     shift_axis=0)
        pipeline = ((hdf5_source, csv_source) + gp.MergeProvider() +
                    gp.RandomLocation(ensure_nonempty=points_key) + shift_node)
        with gp.build(pipeline) as b:
            batch = b.request_batch(request)
            # print(batch[points_key])

        target_vals = [
            self.fake_data[point[0]][point[1]] for point in self.fake_points
        ]
        result_data = batch[array_key].data
        result_points = batch[points_key].data
        result_vals = [
            result_data[int(point.location[0])][int(point.location[1])]
            for point in result_points.values()
        ]

        for result_val in result_vals:
            self.assertTrue(
                result_val in target_vals,
                msg=
                "result value {} at points {} not in target values {} at points {}"
                .format(result_val, list(result_points.values()), target_vals,
                        self.fake_points))
Example #5
    def test_add_parent_vectors(self):
        points = gp.PointsKey("POINTS")
        pv_array = gp.ArrayKey("PARENT_VECTORS")
        mask = gp.ArrayKey("MASK")
        radius = [0.1, 0.1, 0.1, 0.1]
        ts = TracksSource(TEST_FILE, points)
        apv = AddParentVectors(points, pv_array, mask, radius)
        request = gp.BatchRequest()
        request.add(points, gp.Coordinate((3, 4, 4, 4)))
        request.add(pv_array, gp.Coordinate((1, 4, 4, 4)))
        request.add(mask, gp.Coordinate((1, 4, 4, 4)))

        pipeline = (ts + gp.Pad(points, None) + apv)
        with gp.build(pipeline):
            batch = pipeline.request_batch(request)

        points = batch[points].data
        expected_mask = np.zeros(shape=(1, 4, 4, 4))
        expected_mask[0, 0, 0, 0] = 1
        expected_mask[0, 1, 2, 3] = 1

        expected_parent_vectors_z = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_z[0, 1, 2, 3] = -1.0

        expected_parent_vectors_y = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_y[0, 1, 2, 3] = -2.0

        expected_parent_vectors_x = np.zeros(shape=(1, 4, 4, 4))
        expected_parent_vectors_x[0, 1, 2, 3] = -3.0
        # print("MASK")
        # print(batch[mask].data)
        self.assertListEqual(expected_mask.tolist(), batch[mask].data.tolist())

        parent_vectors = batch[pv_array].data
        self.assertListEqual(expected_parent_vectors_z.tolist(),
                             parent_vectors[0].tolist())
        self.assertListEqual(expected_parent_vectors_y.tolist(),
                             parent_vectors[1].tolist())
        self.assertListEqual(expected_parent_vectors_x.tolist(),
                             parent_vectors[2].tolist())
Example #6
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')
    gt_labels = gp.ArrayKey('GT_LABELS')
    gt_affs = gp.ArrayKey('GT_AFFS')
    gt_fgbg = gp.ArrayKey('GT_FGBG')
    gt_cpv = gp.ArrayKey('GT_CPV')
    gt_points = gp.PointsKey('GT_CPV_POINTS')

    loss_weights_affs = gp.ArrayKey('LOSS_WEIGHTS_AFFS')
    loss_weights_fgbg = gp.ArrayKey('LOSS_WEIGHTS_FGBG')
    # loss_weights_cpv = gp.ArrayKey('LOSS_WEIGHTS_CPV')

    pred_affs = gp.ArrayKey('PRED_AFFS')
    pred_fgbg = gp.ArrayKey('PRED_FGBG')
    pred_cpv = gp.ArrayKey('PRED_CPV')

    pred_affs_gradients = gp.ArrayKey('PRED_AFFS_GRADIENTS')
    pred_fgbg_gradients = gp.ArrayKey('PRED_FGBG_GRADIENTS')
    pred_cpv_gradients = gp.ArrayKey('PRED_CPV_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_labels, output_shape_world)
    request.add(gt_fgbg, output_shape_world)
    request.add(anchor, output_shape_world)
    request.add(gt_cpv, output_shape_world)
    request.add(gt_affs, output_shape_world)
    request.add(loss_weights_affs, output_shape_world)
    request.add(loss_weights_fgbg, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(pred_affs, output_shape_world)
    # snapshot_request.add(pred_affs_gradients, output_shape_world)
    snapshot_request.add(gt_fgbg, output_shape_world)
    snapshot_request.add(pred_fgbg, output_shape_world)
    # snapshot_request.add(pred_fgbg_gradients, output_shape_world)
    snapshot_request.add(pred_cpv, output_shape_world)
    # snapshot_request.add(pred_cpv_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for {} not implemented".format(
            kwargs['input_format']))

    fls = []
    shapes = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    # padR = 46
    # padGT = 32

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    pipeline = (
        tuple(
            # read batches from the HDF5 file
            (
                sourceNode(
                    fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        gt_labels: 'volumes/gt_labels',
                        gt_fgbg: 'volumes/gt_fgbg',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        gt_labels: gp.ArraySpec(interpolatable=False),
                        gt_fgbg: gp.ArraySpec(interpolatable=False),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }
                ),
                gp.CsvIDPointsSource(
                    fls[t] + ".csv",
                    gt_points,
                    points_spec=gp.PointsSpec(roi=gp.Roi(
                        gp.Coordinate((0, 0, 0)),
                        gp.Coordinate(shapes[t])))
                )
            )
            + gp.MergeProvider()
            + gp.Pad(raw, None)
            + gp.Pad(gt_points, None)
            + gp.Pad(gt_labels, None)
            + gp.Pad(gt_fgbg, None)

            # choose a random location for each requested batch
            + gp.RandomLocation()

            for t in range(ln)
        ) +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +

        # scale and shift the intensity of the raw array
        gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) +

        # grow a boundary between labels
        gp.GrowBoundary(
            gt_labels,
            steps=1,
            only_xy=False) +

        # convert labels into affinities between voxels
        gp.AddAffinities(
            [[-1, 0, 0], [0, -1, 0], [0, 0, -1]],
            gt_labels,
            gt_affs) +

        gp.AddCPV(
            gt_points,
            gt_labels,
            gt_cpv) +
        # create a weight array that balances positive and negative samples in
        # the affinity array
        gp.BalanceLabels(
            gt_affs,
            loss_weights_affs) +

        gp.BalanceLabels(
            gt_fgbg,
            loss_weights_fgbg) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_affs']: gt_affs,
                net_names['gt_fgbg']: gt_fgbg,
                net_names['anchor']: anchor,
                net_names['gt_cpv']: gt_cpv,
                net_names['gt_labels']: gt_labels,
                net_names['loss_weights_affs']: loss_weights_affs,
                net_names['loss_weights_fgbg']: loss_weights_fgbg
            },
            outputs={
                net_names['pred_affs']: pred_affs,
                net_names['pred_fgbg']: pred_fgbg,
                net_names['pred_cpv']: pred_cpv,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                net_names['pred_affs']: pred_affs_gradients,
                net_names['pred_fgbg']: pred_fgbg_gradients,
                net_names['pred_cpv']: pred_cpv_gradients
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_labels: '/volumes/gt_labels',
                gt_affs: '/volumes/gt_affs',
                gt_fgbg: '/volumes/gt_fgbg',
                gt_cpv: '/volumes/gt_cpv',
                pred_affs: '/volumes/pred_affs',
                pred_affs_gradients: '/volumes/pred_affs_gradients',
                pred_fgbg: '/volumes/pred_fgbg',
                pred_fgbg_gradients: '/volumes/pred_fgbg_gradients',
                pred_cpv: '/volumes/pred_cpv',
                pred_cpv_gradients: '/volumes/pred_cpv_gradients'
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every kwargs['profiling'] iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #7
def train_until(**kwargs):
    if tf.train.latest_checkpoint(kwargs['output_folder']):
        trained_until = int(
            tf.train.latest_checkpoint(kwargs['output_folder']).split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= kwargs['max_iteration']:
        return

    anchor = gp.ArrayKey('ANCHOR')
    raw = gp.ArrayKey('RAW')
    raw_cropped = gp.ArrayKey('RAW_CROPPED')

    points = gp.PointsKey('POINTS')
    gt_cp = gp.ArrayKey('GT_CP')
    pred_cp = gp.ArrayKey('PRED_CP')
    pred_cp_gradients = gp.ArrayKey('PRED_CP_GRADIENTS')

    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_config.json'), 'r') as f:
        net_config = json.load(f)
    with open(
            os.path.join(kwargs['output_folder'],
                         kwargs['name'] + '_names.json'), 'r') as f:
        net_names = json.load(f)

    voxel_size = gp.Coordinate(kwargs['voxel_size'])
    input_shape_world = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_shape_world = gp.Coordinate(net_config['output_shape']) * voxel_size

    # formulate the request for what a batch should (at least) contain
    request = gp.BatchRequest()
    request.add(raw, input_shape_world)
    request.add(raw_cropped, output_shape_world)
    request.add(gt_cp, output_shape_world)
    request.add(anchor, output_shape_world)

    # when we make a snapshot for inspection (see below), we also want to
    # request the predicted affinities and gradients of the loss wrt the
    # affinities
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw_cropped, output_shape_world)
    snapshot_request.add(gt_cp, output_shape_world)
    snapshot_request.add(pred_cp, output_shape_world)
    # snapshot_request.add(pred_cp_gradients, output_shape_world)

    if kwargs['input_format'] != "hdf" and kwargs['input_format'] != "zarr":
        raise NotImplementedError("train node for %s not implemented yet",
                                  kwargs['input_format'])

    fls = []
    shapes = []
    mn = []
    mx = []
    for f in kwargs['data_files']:
        fls.append(os.path.splitext(f)[0])
        if kwargs['input_format'] == "hdf":
            vol = h5py.File(f, 'r')['volumes/raw']
        elif kwargs['input_format'] == "zarr":
            vol = zarr.open(f, 'r')['volumes/raw']
        print(f, vol.shape, vol.dtype)
        shapes.append(vol.shape)
        mn.append(np.min(vol))
        mx.append(np.max(vol))
        if vol.dtype != np.float32:
            print("please convert to float32")
    ln = len(fls)
    print("first 5 files: ", fls[0:4])

    if kwargs['input_format'] == "hdf":
        sourceNode = gp.Hdf5Source
    elif kwargs['input_format'] == "zarr":
        sourceNode = gp.ZarrSource

    augmentation = kwargs['augmentation']
    sources = tuple(
        (sourceNode(fls[t] + "." + kwargs['input_format'],
                    datasets={
                        raw: 'volumes/raw',
                        anchor: 'volumes/gt_fgbg',
                    },
                    array_specs={
                        raw: gp.ArraySpec(interpolatable=True),
                        anchor: gp.ArraySpec(interpolatable=False)
                    }),
         gp.CsvIDPointsSource(fls[t] + ".csv",
                              points,
                              points_spec=gp.PointsSpec(
                                  roi=gp.Roi(gp.Coordinate((
                                      0, 0, 0)), gp.Coordinate(shapes[t]))))) +
        gp.MergeProvider()
        # + Clip(raw, mn=mn[t], mx=mx[t])
        # + NormalizeMinMax(raw, mn=mn[t], mx=mx[t])
        + gp.Pad(raw, None) + gp.Pad(points, None)

        # choose a random location for each requested batch
        + gp.RandomLocation() for t in range(ln))
    pipeline = (
        sources +

        # choose a random source (i.e., sample) from the above
        gp.RandomProvider() +

        # elastically deform the batch
        (gp.ElasticAugment(
            augmentation['elastic']['control_point_spacing'],
            augmentation['elastic']['jitter_sigma'],
            [augmentation['elastic']['rotation_min']*np.pi/180.0,
             augmentation['elastic']['rotation_max']*np.pi/180.0],
            subsample=augmentation['elastic'].get('subsample', 1)) \
        if augmentation.get('elastic') is not None else NoOp())  +

        # apply transpose and mirror augmentations
        gp.SimpleAugment(mirror_only=augmentation['simple'].get("mirror"),
                         transpose_only=augmentation['simple'].get("transpose")) +
        # (gp.SimpleAugment(
        #     mirror_only=augmentation['simple'].get("mirror"),
        #     transpose_only=augmentation['simple'].get("transpose")) \
        # if augmentation.get('simple') is not None and \
        #    augmentation.get('simple') != {} else NoOp())  +

        # scale and shift the intensity of the raw array
        (gp.IntensityAugment(
            raw,
            scale_min=augmentation['intensity']['scale'][0],
            scale_max=augmentation['intensity']['scale'][1],
            shift_min=augmentation['intensity']['shift'][0],
            shift_max=augmentation['intensity']['shift'][1],
            z_section_wise=False) \
        if augmentation.get('intensity') is not None and \
           augmentation.get('intensity') != {} else NoOp())  +

        gp.RasterizePoints(
            points,
            gt_cp,
            array_spec=gp.ArraySpec(voxel_size=voxel_size),
            settings=gp.RasterizationSettings(
                radius=(2, 2, 2),
                mode='peak')) +

        # pre-cache batches from the point upstream
        gp.PreCache(
            cache_size=kwargs['cache_size'],
            num_workers=kwargs['num_workers']) +

        # perform one training iteration for each passing batch (here we use
        # the tensor names earlier stored in train_net.config)
        gp.tensorflow.Train(
            os.path.join(kwargs['output_folder'], kwargs['name']),
            optimizer=net_names['optimizer'],
            summary=net_names['summaries'],
            log_dir=kwargs['output_folder'],
            loss=net_names['loss'],
            inputs={
                net_names['raw']: raw,
                net_names['gt_cp']: gt_cp,
                net_names['anchor']: anchor,
            },
            outputs={
                net_names['pred_cp']: pred_cp,
                net_names['raw_cropped']: raw_cropped,
            },
            gradients={
                # net_names['pred_cp']: pred_cp_gradients,
            },
            save_every=kwargs['checkpoints']) +

        # save the passing batch as an HDF5 file for inspection
        gp.Snapshot(
            {
                raw: '/volumes/raw',
                raw_cropped: '/volumes/raw_cropped',
                gt_cp: '/volumes/gt_cp',
                pred_cp: '/volumes/pred_cp',
                # pred_cp_gradients: '/volumes/pred_cp_gradients',
            },
            output_dir=os.path.join(kwargs['output_folder'], 'snapshots'),
            output_filename='batch_{iteration}.hdf',
            every=kwargs['snapshots'],
            additional_request=snapshot_request,
            compression_type='gzip') +

        # show a summary of the time spent in each node every kwargs['profiling'] iterations
        gp.PrintProfilingStats(every=kwargs['profiling'])
    )

    #########
    # TRAIN #
    #########
    print("Starting training...")
    with gp.build(pipeline):
        print(pipeline)
        for i in range(trained_until, kwargs['max_iteration']):
            # print("request", request)
            start = time.time()
            pipeline.request_batch(request)
            time_of_iteration = time.time() - start

            logger.info("Batch: iteration=%d, time=%f", i, time_of_iteration)
            # exit()
    print("Training finished")
Example #8
File: train.py  Project: funkelab/synful
def build_pipeline(parameter, augment=True):
    voxel_size = gp.Coordinate(parameter['voxel_size'])

    # Array Specifications.
    raw = gp.ArrayKey('RAW')
    gt_neurons = gp.ArrayKey('GT_NEURONS')
    gt_postpre_vectors = gp.ArrayKey('GT_POSTPRE_VECTORS')
    gt_post_indicator = gp.ArrayKey('GT_POST_INDICATOR')
    post_loss_weight = gp.ArrayKey('POST_LOSS_WEIGHT')
    vectors_mask = gp.ArrayKey('VECTORS_MASK')

    pred_postpre_vectors = gp.ArrayKey('PRED_POSTPRE_VECTORS')
    pred_post_indicator = gp.ArrayKey('PRED_POST_INDICATOR')

    grad_syn_indicator = gp.ArrayKey('GRAD_SYN_INDICATOR')
    grad_partner_vectors = gp.ArrayKey('GRAD_PARTNER_VECTORS')

    # Points specifications
    dummypostsyn = gp.PointsKey('DUMMYPOSTSYN')
    postsyn = gp.PointsKey('POSTSYN')
    presyn = gp.PointsKey('PRESYN')
    trg_context = 140  # AddPartnerVectorMap context in nm - pre-post distance

    with open('train_net_config.json', 'r') as f:
        net_config = json.load(f)

    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(gt_neurons, output_size)
    request.add(gt_postpre_vectors, output_size)
    request.add(gt_post_indicator, output_size)
    request.add(post_loss_weight, output_size)
    request.add(vectors_mask, output_size)
    request.add(dummypostsyn, output_size)

    # debug output: print each requested key and its ROI
    for (key, request_spec) in request.items():
        print(key)
        print(request_spec.roi)

    snapshot_request = gp.BatchRequest({
        pred_post_indicator: request[gt_postpre_vectors],
        pred_postpre_vectors: request[gt_postpre_vectors],
        grad_syn_indicator: request[gt_postpre_vectors],
        grad_partner_vectors: request[gt_postpre_vectors],
        vectors_mask: request[gt_postpre_vectors]
    })

    postsyn_rastersetting = gp.RasterizationSettings(
        radius=parameter['blob_radius'],
        mask=gt_neurons,
        mode=parameter['blob_mode'])

    pipeline = tuple([
        create_source(sample, raw, presyn, postsyn, dummypostsyn, parameter,
                      gt_neurons) for sample in samples
    ])

    pipeline += gp.RandomProvider()
    if augment:
        pipeline += gp.ElasticAugment([4, 40, 40], [0, 2, 2],
                                      [0, math.pi / 2.0],
                                      prob_slip=0.05,
                                      prob_shift=0.05,
                                      max_misalign=10,
                                      subsample=8)
        pipeline += gp.SimpleAugment(transpose_only=[1, 2], mirror_only=[1, 2])
        pipeline += gp.IntensityAugment(raw,
                                        0.9,
                                        1.1,
                                        -0.1,
                                        0.1,
                                        z_section_wise=True)
    pipeline += gp.IntensityScaleShift(raw, 2, -1)
    pipeline += gp.RasterizePoints(
        postsyn, gt_post_indicator,
        gp.ArraySpec(voxel_size=voxel_size, dtype=np.int32),
        postsyn_rastersetting)
    spec = gp.ArraySpec(voxel_size=voxel_size)
    pipeline += AddPartnerVectorMap(
        src_points=postsyn,
        trg_points=presyn,
        array=gt_postpre_vectors,
        radius=parameter['d_blob_radius'],
        trg_context=trg_context,  # enlarge
        array_spec=spec,
        mask=gt_neurons,
        pointmask=vectors_mask)
    pipeline += gp.BalanceLabels(labels=gt_post_indicator,
                                 scales=post_loss_weight,
                                 slab=(-1, -1, -1),
                                 clipmin=parameter['cliprange'][0],
                                 clipmax=parameter['cliprange'][1])
    if parameter['d_scale'] != 1:
        pipeline += gp.IntensityScaleShift(gt_postpre_vectors,
                                           scale=parameter['d_scale'],
                                           shift=0)
    pipeline += gp.PreCache(cache_size=40, num_workers=10)
    pipeline += gp.tensorflow.Train(
        './train_net',
        optimizer=net_config['optimizer'],
        loss=net_config['loss'],
        summary=net_config['summary'],
        log_dir='./tensorboard/',
        save_every=30000,  # 10000
        log_every=100,
        inputs={
            net_config['raw']: raw,
            net_config['gt_partner_vectors']: gt_postpre_vectors,
            net_config['gt_syn_indicator']: gt_post_indicator,
            net_config['vectors_mask']: vectors_mask,
            # Loss weights --> mask
            net_config['indicator_weight']: post_loss_weight,  # Loss weights
        },
        outputs={
            net_config['pred_partner_vectors']: pred_postpre_vectors,
            net_config['pred_syn_indicator']: pred_post_indicator,
        },
        gradients={
            net_config['pred_partner_vectors']: grad_partner_vectors,
            net_config['pred_syn_indicator']: grad_syn_indicator,
        },
    )
    # Visualize.
    pipeline += gp.IntensityScaleShift(raw, 0.5, 0.5)
    pipeline += gp.Snapshot(
        {
            raw: 'volumes/raw',
            gt_neurons: 'volumes/labels/neuron_ids',
            gt_post_indicator: 'volumes/gt_post_indicator',
            gt_postpre_vectors: 'volumes/gt_postpre_vectors',
            pred_postpre_vectors: 'volumes/pred_postpre_vectors',
            pred_post_indicator: 'volumes/pred_post_indicator',
            post_loss_weight: 'volumes/post_loss_weight',
            grad_syn_indicator: 'volumes/post_indicator_gradients',
            grad_partner_vectors: 'volumes/partner_vectors_gradients',
            vectors_mask: 'volumes/vectors_mask'
        },
        every=1000,
        output_filename='batch_{iteration}.hdf',
        compression_type='gzip',
        additional_request=snapshot_request)
    pipeline += gp.PrintProfilingStats(every=100)

    print("Starting training...")
    max_iteration = parameter['max_iteration']
    with gp.build(pipeline) as b:
        for i in range(max_iteration):
            b.request_batch(request)
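build_pipeline reads its settings from the parameter dict; a hypothetical dict listing the keys the function actually accesses (all values are placeholders, not the project's real configuration):

parameter = {
    'voxel_size': [40, 4, 4],     # placeholder
    'blob_radius': 100,           # placeholder, world units
    'blob_mode': 'ball',          # one of the gunpowder rasterization modes
    'd_blob_radius': 120,         # placeholder
    'cliprange': [0.01, 0.99],    # clipmin/clipmax for BalanceLabels
    'd_scale': 1,                 # scaling of gt_postpre_vectors
    'max_iteration': 100000,      # placeholder
}
build_pipeline(parameter, augment=True)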
Example #9
    "score_thr": 0,
    "score_type": "sum",
    "nms_radius": None
}

parameters = synful.detection.SynapseExtractionParameters(
    extract_type=parameter_dic['extract_type'],
    cc_threshold=parameter_dic['cc_threshold'],
    loc_type=parameter_dic['loc_type'],
    score_thr=parameter_dic['score_thr'],
    score_type=parameter_dic['score_type'],
    nms_radius=parameter_dic['nms_radius'])

gp.ArrayKey('M_PRED')
gp.ArrayKey('D_PRED')
gp.PointsKey('PRESYN')
gp.PointsKey('POSTSYN')


class TestSource(gp.BatchProvider):
    def __init__(self, m_pred, d_pred, voxel_size):
        self.voxel_size = voxel_size
        self.m_pred = m_pred
        self.d_pred = d_pred

    def setup(self):
        self.provides(
            gp.ArrayKeys.M_PRED,
            gp.ArraySpec(roi=gp.Roi((0, 0, 0), (200, 200, 200)),
                         voxel_size=self.voxel_size,
                         interpolatable=False))
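The example is cut off after setup(). For completeness, a minimal provide() for a test source like this might look as follows (a sketch that assumes m_pred is a numpy array aligned with the provided ROI; it is not the original code):

    def provide(self, request):
        # return only the requested crop of M_PRED
        batch = gp.Batch()
        spec = self.spec[gp.ArrayKeys.M_PRED].copy()
        spec.roi = request[gp.ArrayKeys.M_PRED].roi
        begin = spec.roi.get_begin() / self.voxel_size
        shape = spec.roi.get_shape() / self.voxel_size
        slices = tuple(slice(int(b), int(b + s)) for b, s in zip(begin, shape))
        batch.arrays[gp.ArrayKeys.M_PRED] = gp.Array(self.m_pred[slices], spec)
        return batch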
Example #10
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey('RAW')
    labels = gp.ArrayKey('LABELS')
    labels_fg = gp.ArrayKey('LABELS_FG')

    # array keys for base volume
    raw_base = gp.ArrayKey('RAW_BASE')
    labels_base = gp.ArrayKey('LABELS_BASE')
    swc_base = gp.PointsKey('SWC_BASE')
    swc_center_base = gp.PointsKey('SWC_CENTER_BASE')

    # array keys for add volume
    raw_add = gp.ArrayKey('RAW_ADD')
    labels_add = gp.ArrayKey('LABELS_ADD')
    swc_add = gp.PointsKey('SWC_ADD')
    swc_center_add = gp.PointsKey('SWC_CENTER_ADD')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((3, 3, 3))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)

    request.add(swc_center_base, output_size)
    request.add(swc_base, input_size)

    request.add(swc_center_add, output_size)
    request.add(swc_add, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple()
    data_sources_base += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_base: '/volume',
                       },
                       array_specs={
                           raw_base:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_base, swc_base),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=10) for file in files)
    data_sources_base += gp.RandomProvider()

    # data source for "add" volume
    data_sources_add = tuple()
    data_sources_add += tuple(
        (gp.Hdf5Source(file,
                       datasets={
                           raw_add: '/volume',
                       },
                       array_specs={
                           raw_add:
                           gp.ArraySpec(interpolatable=True,
                                        voxel_size=voxel_size,
                                        dtype=np.uint16),
                       },
                       channels_first=False),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center_add, swc_add),
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            iteration=1) for file in files)
    data_sources_add += gp.RandomProvider()
    data_sources = (data_sources_base, data_sources_add) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(raw_base,
                                     raw_add,
                                     labels_base,
                                     labels_add,
                                     raw,
                                     labels,
                                     blend_mode='labels_mask',
                                     blend_smoothness=10,
                                     num_blended_objects=0) +

        # augment
        gp.ElasticAugment([10, 10, 10], [1, 1, 1], [0, math.pi / 2.0],
                          subsample=8) +
        gp.SimpleAugment(mirror_only=[2], transpose_only=[]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +

        # train
        gp.PreCache(cache_size=40, num_workers=10) +
        gp.tensorflow.Train('./train_net',
                            optimizer=net_names['optimizer'],
                            loss=net_names['loss'],
                            inputs={
                                net_names['raw']: raw,
                                net_names['labels_fg']: labels_fg,
                                net_names['loss_weights']: loss_weights,
                            },
                            outputs={
                                net_names['fg']: fg,
                            },
                            gradients={
                                net_names['fg']: gradient_fg,
                            },
                            save_every=100) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names={
                        raw: 'volumes/raw',
                        raw_base: 'volumes/raw_base',
                        raw_add: 'volumes/raw_add',
                        labels: 'volumes/labels',
                        labels_base: 'volumes/labels_base',
                        labels_add: 'volumes/labels_add',
                        fg: 'volumes/fg',
                        labels_fg: 'volumes/labels_fg',
                        gradient_fg: 'volumes/gradient_fg',
                    },
                    additional_request=snapshot_request,
                    every=10) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #11
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")
    labels = gp.ArrayKey("LABELS")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")

    # array keys for fused volume
    raw_fused = gp.ArrayKey("RAW_FUSED")
    labels_fused = gp.ArrayKey("LABELS_FUSED")
    swc_fused = gp.PointsKey("SWC_FUSED")

    # output data
    fg = gp.ArrayKey("FG")
    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")

    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw_fused, input_size)
    request.add(labels_fused, input_size)
    request.add(swc_fused, input_size)
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    # request.add(fg, output_size)
    # request.add(labels_fg, output_size)
    request.add(gradient_fg, output_size)
    request.add(raw_base, input_size)
    request.add(raw_add, input_size)
    request.add(labels_base, input_size)
    request.add(labels_add, input_size)
    request.add(swc_base, input_size)
    request.add(swc_add, input_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
                    ).absolute()
                ),
                datasets={raw: "volume"},
                array_specs={
                    raw: gp.ArraySpec(
                        interpolatable=True, voxel_size=voxel_size, dtype=np.uint16
                    )
                },
            ),
            MouselightSwcFileSource(
                filename=str(
                    (
                        filename
                        / "consensus-neurons-with-machine-centerpoints-labelled-as-swcs"
                    ).absolute()
                ),
                points=(swcs,),
                scale=voxel_size,
                transpose=(2, 1, 0),
                transform_file=str((filename / "transform.txt").absolute()),
                ignore_human_nodes=True
            ),
        )
        + gp.MergeProvider()
        + gp.RandomLocation(
            ensure_nonempty=swcs, ensure_centered=True
        )
        + RasterizeSkeleton(
            points=swcs,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32
            ),
        )
        + GrowLabels(labels, radius=10)
        # augment
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
        )
        + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw)
        + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for filename in Path(sample_dir).iterdir()
        if "2018-08-01" in filename.name  # or "2018-07-02" in filename.name
    )

    pipeline = (
        data_sources
        + gp.RandomProvider()
        + GetNeuronPair(
            swcs,
            raw,
            labels,
            (swc_base, swc_add),
            (raw_base, raw_add),
            (labels_base, labels_add),
            seperate_by=150,
            shift_attempts=50,
            request_attempts=10,
        )
        + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            swc_base,
            swc_add,
            raw_fused,
            labels_fused,
            swc_fused,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        )
        + Crop(labels_fused, labels_fg)
        + BinarizeGt(labels_fg, labels_fg_bin)
        + gp.BalanceLabels(labels_fg_bin, loss_weights)
        # train
        + gp.PreCache(cache_size=40, num_workers=10)
        + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw_fused,
                net_names["labels_fg"]: labels_fg_bin,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        )
        + gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw_fused: "volumes/raw_fused",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels_fused: "volumes/labels_fused",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                labels_fg_bin: "volumes/labels_fg_bin",
                fg: "volumes/pred_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            every=100,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):

        logging.info("Starting training...")
        for i in range(max_iteration - trained_until):
            logging.info("requesting batch {}".format(i))
            batch = pipeline.request_batch(request)
            """
Example #12
def get_neuron_pair(
    pipeline,
    setup_config,
    raw: ArrayKey,
    labels: ArrayKey,
    matched: PointsKey,
    nonempty_placeholder: PointsKey,
):

    # Data Properties
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    # Somewhat arbitrary hyperparameters
    blend_mode = setup_config["BLEND_MODE"]
    shift_attempts = setup_config["SHIFT_ATTEMPTS"]
    request_attempts = setup_config["REQUEST_ATTEMPTS"]
    blend_smoothness = setup_config["BLEND_SMOOTHNESS"]
    seperate_by = setup_config["SEPERATE_BY"]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for fused volume
    raw_fused = ArrayKey("RAW_FUSED")
    labels_fused = ArrayKey("LABELS_FUSED")
    matched_fused = gp.PointsKey("MATCHED_FUSED")

    # array keys for base volume
    raw_base = ArrayKey("RAW_BASE")
    labels_base = ArrayKey("LABELS_BASE")
    matched_base = gp.PointsKey("MATCHED_BASE")

    # array keys for add volume
    raw_add = ArrayKey("RAW_ADD")
    labels_add = ArrayKey("LABELS_ADD")
    matched_add = gp.PointsKey("MATCHED_ADD")

    # debug array keys
    soft_mask = gp.ArrayKey("SOFT_MASK")
    masked_base = gp.ArrayKey("MASKED_BASE")
    masked_add = gp.ArrayKey("MASKED_ADD")
    mask_maxed = gp.ArrayKey("MASK_MAXED")

    pipeline = pipeline + GetNeuronPair(
        matched,
        raw,
        labels,
        (matched_base, matched_add),
        (raw_base, raw_add),
        (labels_base, labels_add),
        output_shape=output_size,
        seperate_by=seperate_distance,
        shift_attempts=shift_attempts,
        request_attempts=request_attempts,
        nonempty_placeholder=nonempty_placeholder,
    )
    if blend_mode == "add":
        if setup_config["PRE_CLAHE"]:
            pipeline = pipeline + scipyCLAHE(
                [raw_add, raw_base],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=float(setup_config["CLIP_LIMIT"]),
                normalize=setup_config["CLAHE_NORMALIZE"],
            )
    pipeline = pipeline + FusionAugment(
        raw_base,
        raw_add,
        labels_base,
        labels_add,
        matched_base,
        matched_add,
        raw_fused,
        labels_fused,
        matched_fused,
        soft_mask=soft_mask,
        masked_base=masked_base,
        masked_add=masked_add,
        mask_maxed=mask_maxed,
        blend_mode=blend_mode,
        blend_smoothness=blend_smoothness * micron_scale,
        gaussian_smooth_mode="mirror",  # TODO: Config this
        num_blended_objects=0,  # TODO: Config this
    )
    if blend_mode == "add":
        if setup_config["POST_CLAHE"]:
            pipeline = pipeline + scipyCLAHE(
                [raw_add, raw_base],
                gp.Coordinate([20, 64, 64]) * voxel_size,
                clip_limit=float(setup_config["CLIP_LIMIT"]),
                normalize=setup_config["CLAHE_NORMALIZE"],
            )

    return (
        pipeline,
        raw_fused,
        labels_fused,
        matched_fused,
        raw_base,
        labels_base,
        matched_base,
        raw_add,
        labels_add,
        matched_add,
        soft_mask,
        masked_base,
        masked_add,
        mask_maxed,
    )
Example #13
def get_mouselight_data_sources(setup_config: Dict[str, Any],
                                source_samples: List[str],
                                locations=False):
    # Source Paths and accessibility
    raw_n5 = setup_config["RAW_N5"]
    mongo_url = setup_config["MONGO_URL"]
    samples_path = Path(setup_config["SAMPLES_PATH"])

    # specified_locations = setup_config.get("SPECIFIED_LOCATIONS")

    # Graph matching parameters
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    matching_failures_dir = setup_config["MATCHING_FAILURES_DIR"]
    matching_failures_dir = (matching_failures_dir
                             if matching_failures_dir is None else
                             Path(matching_failures_dir))

    # Data Properties
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    output_size = output_shape * voxel_size
    micron_scale = voxel_size[0]

    distance_attr = setup_config["DISTANCE_ATTRIBUTE"]
    target_distance = float(setup_config["MIN_DIST_TO_FALLBACK"])
    max_nonempty_points = int(setup_config["MAX_RANDOM_LOCATION_POINTS"])

    mongo_db_template = setup_config["MONGO_DB_TEMPLATE"]
    matched_source = setup_config.get("MATCHED_SOURCE", "matched")

    # New array keys
    # Note: These are intended to be requested with size input_size
    raw = ArrayKey("RAW")
    matched = gp.PointsKey("MATCHED")
    nonempty_placeholder = gp.PointsKey("NONEMPTY")
    labels = ArrayKey("LABELS")

    ensure_nonempty = nonempty_placeholder

    node_offset = {
        sample.name: (daisy.persistence.MongoDbGraphProvider(
            mongo_db_template.format(sample=sample.name,
                                     source="skeletonization"),
            mongo_url,
        ).num_nodes(None) + 1)
        for sample in samples_path.iterdir() if sample.name in source_samples
    }

    # if specified_locations is not None:
    #     centers = pickle.load(open(specified_locations, "rb"))
    #     random = gp.SpecifiedLocation
    #     kwargs = {"locations": centers, "choose_randomly": True}
    #     logger.info(f"Using specified locations from {specified_locations}")
    # elif locations:
    #     random = RandomLocations
    #     kwargs = {
    #         "ensure_nonempty": ensure_nonempty,
    #         "ensure_centered": True,
    #         "point_balance_radius": point_balance_radius * micron_scale,
    #         "loc": gp.ArrayKey("RANDOM_LOCATION"),
    #     }
    # else:

    random = RandomLocation
    kwargs = {
        "ensure_nonempty": ensure_nonempty,
        "ensure_centered": True,
        "point_balance_radius": point_balance_radius * micron_scale,
    }

    data_sources = (tuple(
        (
            gp.ZarrSource(
                filename=str((sample / raw_n5).absolute()),
                datasets={raw: "volume-rechunked"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            DaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[matched],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            FilteredDaisyGraphProvider(
                mongo_db_template.format(sample=sample.name,
                                         source=matched_source),
                mongo_url,
                points=[nonempty_placeholder],
                directed=True,
                node_attrs=["distance_to_fallback"],
                edge_attrs=[],
                num_nodes=max_nonempty_points,
                dist_attribute=distance_attr,
                min_dist=target_distance,
            ),
        ) + gp.MergeProvider() + random(**kwargs) + gp.Normalize(raw) +
        FilterComponents(
            matched, node_offset[sample.name], centroid_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.int64),
        ) for sample in samples_path.iterdir()
        if sample.name in source_samples) + gp.RandomProvider())

    return (data_sources, raw, labels, nonempty_placeholder, matched)
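A sketch of how this source helper composes with get_neuron_pair from the previous example (the setup_config and source_samples values come from the project's configuration; this wiring is an assumption based on the two signatures):

data_sources, raw, labels, nonempty_placeholder, matched = get_mouselight_data_sources(
    setup_config, source_samples, locations=True)

(pipeline, raw_fused, labels_fused, matched_fused,
 raw_base, labels_base, matched_base,
 raw_add, labels_add, matched_add,
 soft_mask, masked_base, masked_add, mask_maxed) = get_neuron_pair(
    data_sources, setup_config, raw, labels, matched, nonempty_placeholder)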
Example #14
File: train.py  Project: pattonw/mouselight
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for fused volume
    raw = gp.ArrayKey("RAW")
    labels = gp.ArrayKey("LABELS")
    labels_fg = gp.ArrayKey("LABELS_FG")

    # array keys for base volume
    raw_base = gp.ArrayKey("RAW_BASE")
    labels_base = gp.ArrayKey("LABELS_BASE")
    swc_base = gp.PointsKey("SWC_BASE")
    swc_center_base = gp.PointsKey("SWC_CENTER_BASE")

    # array keys for add volume
    raw_add = gp.ArrayKey("RAW_ADD")
    labels_add = gp.ArrayKey("LABELS_ADD")
    swc_add = gp.PointsKey("SWC_ADD")
    swc_center_add = gp.PointsKey("SWC_CENTER_ADD")

    # output data
    fg = gp.ArrayKey("FG")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size
    output_size = gp.Coordinate(net_config["output_shape"]) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(labels, output_size)
    request.add(labels_fg, output_size)
    request.add(loss_weights, output_size)
    request.add(swc_center_base, output_size)
    request.add(swc_center_add, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(fg, output_size)
    snapshot_request.add(labels_fg, output_size)
    snapshot_request.add(gradient_fg, output_size)
    snapshot_request.add(raw_base, input_size)
    snapshot_request.add(raw_add, input_size)
    snapshot_request.add(labels_base, input_size)
    snapshot_request.add(labels_add, input_size)

    # data source for "base" volume
    data_sources_base = tuple(
        (
            gp.Hdf5Source(
                filename,
                datasets={raw_base: "/volume"},
                array_specs={
                    raw_base:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=filename,
                dataset="/reconstruction",
                points=(swc_center_base, swc_base),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_base) + RasterizeSkeleton(
            points=swc_base,
            array=labels_base,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for filename in files)

    # data source for "add" volume
    data_sources_add = tuple(
        (
            gp.Hdf5Source(
                file,
                datasets={raw_add: "/volume"},
                array_specs={
                    raw_add:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
                channels_first=False,
            ),
            SwcSource(
                filename=file,
                dataset="/reconstruction",
                points=(swc_center_add, swc_add),
                scale=voxel_size,
            ),
        ) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center_add) + RasterizeSkeleton(
            points=swc_add,
            array=labels_add,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            radius=5.0,
        ) for file in files)
    data_sources = (
        (data_sources_base + gp.RandomProvider()),
        (data_sources_add + gp.RandomProvider()),
    ) + gp.MergeProvider()

    pipeline = (
        data_sources + FusionAugment(
            raw_base,
            raw_add,
            labels_base,
            labels_add,
            raw,
            labels,
            blend_mode="labels_mask",
            blend_smoothness=10,
            num_blended_objects=0,
        ) +
        # augment
        gp.ElasticAugment([40, 10, 10], [0.25, 1, 1], [0, math.pi / 2.0],
                          subsample=4) +
        gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2]) +
        gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(labels, labels_fg) +
        gp.BalanceLabels(labels_fg, loss_weights) +
        # train
        gp.PreCache(cache_size=40, num_workers=10) + gp.tensorflow.Train(
            "./train_net",
            optimizer=net_names["optimizer"],
            loss=net_names["loss"],
            inputs={
                net_names["raw"]: raw,
                net_names["labels_fg"]: labels_fg,
                net_names["loss_weights"]: loss_weights,
            },
            outputs={net_names["fg"]: fg},
            gradients={net_names["fg"]: gradient_fg},
            save_every=100000,
        ) +
        # visualize
        gp.Snapshot(
            output_filename="snapshot_{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                raw_base: "volumes/raw_base",
                raw_add: "volumes/raw_add",
                labels: "volumes/labels",
                labels_base: "volumes/labels_base",
                labels_add: "volumes/labels_add",
                fg: "volumes/fg",
                labels_fg: "volumes/labels_fg",
                gradient_fg: "volumes/gradient_fg",
            },
            additional_request=snapshot_request,
            every=100,
        ) + gp.PrintProfilingStats(every=100))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #15
0
def train(n_iterations, setup_config, mknet_tensor_names, loss_tensor_names):

    # Network hyperparams
    INPUT_SHAPE = setup_config["INPUT_SHAPE"]
    OUTPUT_SHAPE = setup_config["OUTPUT_SHAPE"]

    # Skeleton generation hyperparams
    SKEL_GEN_RADIUS = setup_config["SKEL_GEN_RADIUS"]
    THETAS = np.array(setup_config["THETAS"]) * math.pi
    SPLIT_PS = setup_config["SPLIT_PS"]
    NOISE_VAR = setup_config["NOISE_VAR"]
    N_OBJS = setup_config["N_OBJS"]

    # Skeleton variation hyperparams
    LABEL_RADII = setup_config["LABEL_RADII"]
    RAW_RADII = setup_config["RAW_RADII"]
    RAW_INTENSITIES = setup_config["RAW_INTENSITIES"]

    # Training hyperparams
    CACHE_SIZE = setup_config["CACHE_SIZE"]
    NUM_WORKERS = setup_config["NUM_WORKERS"]
    SNAPSHOT_EVERY = setup_config["SNAPSHOT_EVERY"]
    CHECKPOINT_EVERY = setup_config["CHECKPOINT_EVERY"]

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    gt_fg = gp.ArrayKey("GT_FG")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # tensorflow tensors
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    snapshot_request.add(embedding,
                         OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(maxima,
                         OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gradient_embedding,
                         OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request.add(gradient_fg,
                         OUTPUT_SHAPE,
                         voxel_size=gp.Coordinate((1, 1)))
    snapshot_request[emst] = gp.ArraySpec()
    snapshot_request[edges_u] = gp.ArraySpec()
    snapshot_request[edges_v] = gp.ArraySpec()
    snapshot_request[ratio_pos] = gp.ArraySpec()
    snapshot_request[ratio_neg] = gp.ArraySpec()
    snapshot_request[dist] = gp.ArraySpec()
    snapshot_request[num_pos_pairs] = gp.ArraySpec()
    snapshot_request[num_neg_pairs] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            raw,
            gp.ArraySpec(
                roi=gp.Roi((None, ) * 2, (None, ) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        ) + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None, ) * 2, (None, ) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
            use_component=True,
            n_objs=int(setup_config["HIDE_SIGNAL"]),
        ) + nl.GrowLabels(labels, radii=LABEL_RADII) +
        nl.GrowLabels(raw, radii=RAW_RADII) +
        LabelToFloat32(raw, intensities=RAW_INTENSITIES) +
        gp.NoiseAugment(raw, var=NOISE_VAR) +
        gp.PreCache(cache_size=CACHE_SIZE, num_workers=NUM_WORKERS) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                "strided_slice_1:0": maxima,
                "gt_fg:0": gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=CHECKPOINT_EVERY,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
            },
            dataset_dtypes={
                maxima: np.float32,
                gt_fg: np.float32
            },
            every=SNAPSHOT_EVERY,
            additional_request=snapshot_request,
        )
        # + gp.PrintProfilingStats(every=100)
    )

    with gp.build(pipeline):
        for i in range(n_iterations + 1):
            pipeline.request_batch(request)
            request._update_random_seed()
Example #16
0
def rasterize_graph(
        graph,
        position_attribute,
        radius_pos,
        radius_tolerance,
        roi,
        voxel_size):
    '''Rasterizes a geometric graph into a numpy array.

    For that, the nodes in the graph are assumed to have a position in 3D (see
    parameter ``position_attribute``).

    The created array will have edges painted with 1, background with 0, and
    (optionally) a tolerance region around each edge with -1.

    Args:

        graph (networkx graph):

            The graph to rasterize. Nodes need to have a position attribute.

        position_attribute (string):

            The name of the position attribute of the nodes. The attribute
            should contain tuples of the form ``(z, y, x)`` in world units.

        radius_pos (float):

            The radius of the lines to draw for each edge in the graph (in
            world units).

        radius_tolerance (float):

            The radius of a region around each edge line that will be labelled
            with ``np.uint64(-1)``. Should be larger than ``radius_pos``. If
            set to ``None``, no such label will be produced.

        roi (gp.Roi):

            The ROI of the area to rasterize.

        voxel_size (tuple of int):

            The size of a voxel in the array to create, in world units.
    '''

    graph_key = gp.PointsKey('GRAPH')
    array = gp.ArrayKey('ARRAY')
    array_spec = gp.ArraySpec(voxel_size=voxel_size, dtype=np.uint64)

    pipeline_pos = (
        NetworkxSource(graph, graph_key) +
        RasterizeSkeleton(graph_key, array, array_spec, radius_pos))
    if radius_tolerance is not None:
        # assumption: grow the rasterized labels by radius_tolerance to mark
        # the optional tolerance region (the original line referenced the
        # undefined names ``tolerance`` and ``tolerance_spec``)
        pipeline_pos += GrowLabels(array, radii=[radius_tolerance])

    request = gp.BatchRequest()
    request[array] = gp.ArraySpec(roi=roi)

    with gp.build(pipeline_pos):
        batch = pipeline_pos.request_batch(request)
        return batch[array].data
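A minimal usage sketch (an illustration only; it assumes networkx is
installed, that NetworkxSource, RasterizeSkeleton and GrowLabels are
importable from the surrounding project, and that node positions are stored
under the attribute name passed as position_attribute):

import networkx as nx
import gunpowder as gp

# toy graph: two nodes joined by one edge, positions in (z, y, x) world units
graph = nx.Graph()
graph.add_node(1, position=(0.0, 0.0, 0.0))
graph.add_node(2, position=(0.0, 10.0, 10.0))
graph.add_edge(1, 2)

labels = rasterize_graph(
    graph,
    position_attribute="position",
    radius_pos=1.0,
    radius_tolerance=2.0,
    roi=gp.Roi((0, 0, 0), (20, 20, 20)),
    voxel_size=(1, 1, 1),
)
# per the docstring: edges are painted 1, background is 0, and an optional
# tolerance region around each edge is marked when radius_tolerance is given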
Example #17
0
def train_until(max_iteration, return_intermediates=False):

    # get the latest checkpoint
    if tf.train.latest_checkpoint('.'):
        trained_until = int(tf.train.latest_checkpoint('.').split('_')[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # input data
    ch1 = gp.ArrayKey('CH1')
    ch2 = gp.ArrayKey('CH2')
    swc = gp.PointsKey('SWC')
    swc_env = gp.PointsKey('SWC_ENV')
    swc_center = gp.PointsKey('SWC_CENTER')
    gt = gp.ArrayKey('GT')
    gt_fg = gp.ArrayKey('GT_FG')

    # intermediate outputs of the fusion augment; defined unconditionally so
    # the pipeline below can reference them, but only requested (and saved in
    # snapshots) when return_intermediates is True
    a_ch1 = gp.ArrayKey('A_CH1')
    a_ch2 = gp.ArrayKey('A_CH2')
    b_ch1 = gp.ArrayKey('B_CH1')
    b_ch2 = gp.ArrayKey('B_CH2')
    soft_mask = gp.ArrayKey('SOFT_MASK')

    # output data
    fg = gp.ArrayKey('FG')
    gradient_fg = gp.ArrayKey('GRADIENT_FG')
    loss_weights = gp.ArrayKey('LOSS_WEIGHTS')

    voxel_size = gp.Coordinate((4, 1, 1))
    input_size = gp.Coordinate(net_config['input_shape']) * voxel_size
    output_size = gp.Coordinate(net_config['output_shape']) * voxel_size

    # add request
    request = gp.BatchRequest()
    request.add(ch1, input_size)
    request.add(ch2, input_size)
    request.add(swc, input_size)
    request.add(swc_center, output_size)
    request.add(gt, output_size)
    request.add(gt_fg, output_size)
    # request.add(loss_weights, output_size)

    if return_intermediates:

        request.add(a_ch1, input_size)
        request.add(a_ch2, input_size)
        request.add(b_ch1, input_size)
        request.add(b_ch2, input_size)
        request.add(soft_mask, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    # snapshot_request[fg] = request[gt]
    # snapshot_request[gt_fg] = request[gt]
    # snapshot_request[gradient_fg] = request[gt]

    data_sources = tuple()
    data_sources += tuple(
        (Hdf5ChannelSource(file,
                           datasets={
                               ch1: '/volume',
                               ch2: '/volume',
                           },
                           channel_ids={
                               ch1: 0,
                               ch2: 1,
                           },
                           data_format='channels_last',
                           array_specs={
                               ch1:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                               ch2:
                               gp.ArraySpec(interpolatable=True,
                                            voxel_size=voxel_size,
                                            dtype=np.uint16),
                           }),
         SwcSource(filename=file,
                   dataset='/reconstruction',
                   points=(swc_center, swc),
                   return_env=True,
                   scale=voxel_size)) + gp.MergeProvider() +
        gp.RandomLocation(ensure_nonempty=swc_center) + RasterizeSkeleton(
            points=swc,
            array=gt,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
            points_env=swc_env,
            iteration=10) for file in files)

    snapshot_datasets = {}

    if return_intermediates:

        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            a_ch1: 'volumes/a_ch1',
            a_ch2: 'volumes/a_ch2',
            b_ch1: 'volumes/b_ch1',
            b_ch2: 'volumes/b_ch2',
            soft_mask: 'volumes/soft_mask',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }

    else:

        snapshot_datasets = {
            ch1: 'volumes/ch1',
            ch2: 'volumes/ch2',
            gt: 'volumes/gt',
            fg: 'volumes/fg',
            gt_fg: 'volumes/gt_fg',
            gradient_fg: 'volumes/gradient_fg',
        }

    pipeline = (
        data_sources +
        #gp.RandomProvider() +
        FusionAugment(ch1,
                      ch2,
                      gt,
                      smoothness=1,
                      return_intermediate=return_intermediates) +

        # augment
        #gp.ElasticAugment(...) +
        #gp.SimpleAugment() +
        gp.Normalize(ch1) + gp.Normalize(ch2) + gp.Normalize(a_ch1) +
        gp.Normalize(a_ch2) + gp.Normalize(b_ch1) + gp.Normalize(b_ch2) +
        gp.IntensityAugment(ch1, 0.9, 1.1, -0.001, 0.001) +
        gp.IntensityAugment(ch2, 0.9, 1.1, -0.001, 0.001) +
        BinarizeGt(gt, gt_fg) +

        # visualize
        gp.Snapshot(output_filename='snapshot_{iteration}.hdf',
                    dataset_names=snapshot_datasets,
                    additional_request=snapshot_request,
                    every=20) + gp.PrintProfilingStats(every=1000))

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            pipeline.request_batch(request)
Example #18
0
def train(n_iterations):

    point_trees = gp.PointsKey("POINT_TREES")
    labels = gp.ArrayKey("LABELS")
    raw = gp.ArrayKey("RAW")
    # gt_fg = gp.ArrayKey("GT_FG")
    # embedding = gp.ArrayKey("EMBEDDING")
    # fg = gp.ArrayKey("FG")
    # maxima = gp.ArrayKey("MAXIMA")
    # gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    # gradient_fg = gp.ArrayKey("GRADIENT_FG")
    # emst = gp.ArrayKey("EMST")
    # edges_u = gp.ArrayKey("EDGES_U")
    # edges_v = gp.ArrayKey("EDGES_V")

    request = gp.BatchRequest()
    request.add(raw, INPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(labels, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    request.add(point_trees, INPUT_SHAPE)

    snapshot_request = gp.BatchRequest()
    snapshot_request.add(raw, INPUT_SHAPE)
    # snapshot_request.add(embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(gt_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(maxima, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request.add(
    #     gradient_embedding, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1))
    # )
    # snapshot_request.add(gradient_fg, OUTPUT_SHAPE, voxel_size=gp.Coordinate((1, 1)))
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()

    pipeline = (
        nl.SyntheticLightLike(
            point_trees,
            dims=2,
            r=SKEL_GEN_RADIUS,
            n_obj=N_OBJS,
            thetas=THETAS,
            split_ps=SPLIT_PS,
        )
        # + gp.SimpleAugment()
        # + gp.ElasticAugment([10, 10], [0.1, 0.1], [0, 2.0 * math.pi], spatial_dims=2)
        + nl.RasterizeSkeleton(
            point_trees,
            labels,
            gp.ArraySpec(
                roi=gp.Roi((None,) * 2, (None,) * 2),
                voxel_size=gp.Coordinate((1, 1)),
                dtype=np.uint64,
            ),
        )
        + gp.Copy(labels, raw)
        + nl.GrowLabels(labels, radii=LABEL_RADII)
        + nl.GrowLabels(raw, radii=RAW_RADII)
        + LabelToFloat32(raw, intensities=RAW_INTENSITIES)
        + gp.NoiseAugment(raw, var=NOISE_VAR)
        # + gp.PreCache(cache_size=40, num_workers=10)
        # + gp.tensorflow.Train(
        #     "train_net",
        #     optimizer=add_loss,
        #     loss=None,
        #     inputs={tensor_names["raw"]: raw, tensor_names["gt_labels"]: labels},
        #     outputs={
        #         tensor_names["embedding"]: embedding,
        #         tensor_names["fg"]: fg,
        #         "maxima:0": maxima,
        #         "gt_fg:0": gt_fg,
        #         emst_name: emst,
        #         edges_u_name: edges_u,
        #         edges_v_name: edges_v,
        #     },
        #     gradients={
        #         tensor_names["embedding"]: gradient_embedding,
        #         tensor_names["fg"]: gradient_fg,
        #     },
        # )
        + gp.Snapshot(
            output_filename="{iteration}.hdf",
            dataset_names={
                raw: "volumes/raw",
                labels: "volumes/labels",
                point_trees: "point_trees",
                # embedding: "volumes/embedding",
                # fg: "volumes/fg",
                # maxima: "volumes/maxima",
                # gt_fg: "volumes/gt_fg",
                # gradient_embedding: "volumes/gradient_embedding",
                # gradient_fg: "volumes/gradient_fg",
                # emst: "emst",
                # edges_u: "edges_u",
                # edges_v: "edges_v",
            },
            # dataset_dtypes={maxima: np.float32, gt_fg: np.float32},
            every=100,
            additional_request=snapshot_request,
        )
        + gp.PrintProfilingStats(every=10)
    )

    with gp.build(pipeline):
        for i in range(n_iterations):
            pipeline.request_batch(request)
Example #19
0
            ))]
        output_spec = copy.deepcopy(input_spec)
        output_spec.roi = output_roi

        output_array = gp.Array(output_data, output_spec)

        batch[self.output_array] = output_array


input_size = Coordinate([74, 260, 260])
output_size = Coordinate([42, 168, 168])
path_to_data = Path("/nrs/funke/mouselight-v2")

# array keys for data sources
raw = gp.ArrayKey("RAW")
swcs = gp.PointsKey("SWCS")
labels = gp.ArrayKey("LABELS")

# array keys for base volume
raw_base = gp.ArrayKey("RAW_BASE")
labels_base = gp.ArrayKey("LABELS_BASE")
swc_base = gp.PointsKey("SWC_BASE")

# array keys for add volume
raw_add = gp.ArrayKey("RAW_ADD")
labels_add = gp.ArrayKey("LABELS_ADD")
swc_add = gp.PointsKey("SWC_ADD")

# array keys for fused volume
raw_fused = gp.ArrayKey("RAW_FUSED")
labels_fused = gp.ArrayKey("LABELS_FUSED")
Example #20
0
def train_distance_pipeline(n_iterations, setup_config, mknet_tensor_names,
                            loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    max_label_dist = setup_config["MAX_LABEL_DIST"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    dist = gp.ArrayKey("DIST")
    dist_mask = gp.ArrayKey("DIST_MASK")
    dist_cropped = gp.ArrayKey("DIST_CROPPED")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    fg_dist = gp.ArrayKey("FG_DIST")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")

    # add request
    request = gp.BatchRequest()
    request.add(dist_mask, output_size)
    request.add(dist_cropped, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(dist, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)
    request.add(loss_weights, output_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()

    # tensorflow requests
    snapshot_request.add(raw, input_size)  # input_size request for positioning
    snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    snapshot_request.add(fg_dist, output_size, voxel_size=voxel_size)

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched, center_size=output_size) +
        RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + gp.contrib.nodes.add_distance.AddDistance(
            labels,
            dist,
            dist_mask,
            max_distance=max_label_dist * micron_scale) +
        gp.contrib.nodes.tanh_saturate.TanhSaturate(
            dist, scale=micron_scale, offset=1)
        + ThresholdMask(dist, loss_weights, 1e-4)
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(dist, dist_cropped)
        # + gp.PreCache(cache_size=cache_size, num_workers=num_workers)
        + gp.tensorflow.Train(
            "train_net_foreground",
            optimizer=mknet_tensor_names["optimizer"],
            loss=mknet_tensor_names["fg_loss"],
            inputs={
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_distances"]: dist_cropped,
                mknet_tensor_names["loss_weights"]: loss_weights,
            },
            outputs={mknet_tensor_names["fg_pred"]: fg_dist},
            gradients={mknet_tensor_names["fg_pred"]: gradient_fg},
            save_every=checkpoint_every,
            # summary=mknet_tensor_names["summaries"],
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                labels: "volumes/labels",
                # labeled data
                dist_cropped: "volumes/dist",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                fg_dist: "volumes/fg_dist",
                gradient_fg: "volumes/gradient_fg",
                # output debug data
                dist_mask: "volumes/dist_mask",
                loss_weights: "volumes/loss_weights"
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #21
0
def train_simple_pipeline(n_iterations, setup_config, mknet_tensor_names,
                          loss_tensor_names):
    input_shape = gp.Coordinate(setup_config["INPUT_SHAPE"])
    output_shape = gp.Coordinate(setup_config["OUTPUT_SHAPE"])
    voxel_size = gp.Coordinate(setup_config["VOXEL_SIZE"])
    num_iterations = setup_config["NUM_ITERATIONS"]
    cache_size = setup_config["CACHE_SIZE"]
    num_workers = setup_config["NUM_WORKERS"]
    snapshot_every = setup_config["SNAPSHOT_EVERY"]
    checkpoint_every = setup_config["CHECKPOINT_EVERY"]
    profile_every = setup_config["PROFILE_EVERY"]
    seperate_by = setup_config["SEPERATE_BY"]
    gap_crossing_dist = setup_config["GAP_CROSSING_DIST"]
    match_distance_threshold = setup_config["MATCH_DISTANCE_THRESHOLD"]
    point_balance_radius = setup_config["POINT_BALANCE_RADIUS"]
    neuron_radius = setup_config["NEURON_RADIUS"]

    samples_path = Path(setup_config["SAMPLES_PATH"])
    mongo_url = setup_config["MONGO_URL"]

    input_size = input_shape * voxel_size
    output_size = output_shape * voxel_size
    # voxels have size ~= 1 micron on z axis
    # use this value to scale anything that depends on world unit distance
    micron_scale = voxel_size[0]
    seperate_distance = (np.array(seperate_by)).tolist()

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    consensus = gp.PointsKey("CONSENSUS")
    skeletonization = gp.PointsKey("SKELETONIZATION")
    matched = gp.PointsKey("MATCHED")
    labels = gp.ArrayKey("LABELS")

    labels_fg = gp.ArrayKey("LABELS_FG")
    labels_fg_bin = gp.ArrayKey("LABELS_FG_BIN")
    loss_weights = gp.ArrayKey("LOSS_WEIGHTS")

    # tensorflow tensors
    gt_fg = gp.ArrayKey("GT_FG")
    fg_pred = gp.ArrayKey("FG_PRED")
    embedding = gp.ArrayKey("EMBEDDING")
    fg = gp.ArrayKey("FG")
    maxima = gp.ArrayKey("MAXIMA")
    gradient_embedding = gp.ArrayKey("GRADIENT_EMBEDDING")
    gradient_fg = gp.ArrayKey("GRADIENT_FG")
    emst = gp.ArrayKey("EMST")
    edges_u = gp.ArrayKey("EDGES_U")
    edges_v = gp.ArrayKey("EDGES_V")
    ratio_pos = gp.ArrayKey("RATIO_POS")
    ratio_neg = gp.ArrayKey("RATIO_NEG")
    dist = gp.ArrayKey("DIST")
    num_pos_pairs = gp.ArrayKey("NUM_POS")
    num_neg_pairs = gp.ArrayKey("NUM_NEG")

    # add request
    request = gp.BatchRequest()
    request.add(labels_fg, output_size)
    request.add(labels_fg_bin, output_size)
    request.add(loss_weights, output_size)
    request.add(raw, input_size)
    request.add(labels, input_size)
    request.add(matched, input_size)
    request.add(skeletonization, input_size)
    request.add(consensus, input_size)

    # add snapshot request
    snapshot_request = gp.BatchRequest()
    snapshot_request.add(labels_fg, output_size)

    # tensorflow requests
    # snapshot_request.add(raw, input_size)  # input_size request for positioning
    # snapshot_request.add(embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gt_fg, output_size, voxel_size=voxel_size)
    # snapshot_request.add(fg_pred, output_size, voxel_size=voxel_size)
    # snapshot_request.add(maxima, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_embedding, output_size, voxel_size=voxel_size)
    # snapshot_request.add(gradient_fg, output_size, voxel_size=voxel_size)
    # snapshot_request[emst] = gp.ArraySpec()
    # snapshot_request[edges_u] = gp.ArraySpec()
    # snapshot_request[edges_v] = gp.ArraySpec()
    # snapshot_request[ratio_pos] = gp.ArraySpec()
    # snapshot_request[ratio_neg] = gp.ArraySpec()
    # snapshot_request[dist] = gp.ArraySpec()
    # snapshot_request[num_pos_pairs] = gp.ArraySpec()
    # snapshot_request[num_neg_pairs] = gp.ArraySpec()

    data_sources = tuple(
        (
            gp.N5Source(
                filename=str((sample /
                              "fluorescence-near-consensus.n5").absolute()),
                datasets={raw: "volume"},
                array_specs={
                    raw:
                    gp.ArraySpec(interpolatable=True,
                                 voxel_size=voxel_size,
                                 dtype=np.uint16)
                },
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-consensus",
                mongo_url,
                points=[consensus],
                directed=True,
                node_attrs=[],
                edge_attrs=[],
            ),
            gp.DaisyGraphProvider(
                f"mouselight-{sample.name}-skeletonization",
                mongo_url,
                points=[skeletonization],
                directed=False,
                node_attrs=[],
                edge_attrs=[],
            ),
        ) + gp.MergeProvider() + gp.RandomLocation(
            ensure_nonempty=consensus,
            ensure_centered=True,
            point_balance_radius=point_balance_radius * micron_scale,
        ) + TopologicalMatcher(
            skeletonization,
            consensus,
            matched,
            failures=Path("matching_failures_slow"),
            match_distance_threshold=match_distance_threshold * micron_scale,
            max_gap_crossing=gap_crossing_dist * micron_scale,
            try_complete=False,
            use_gurobi=True,
        ) + RejectIfEmpty(matched) + RasterizeSkeleton(
            points=matched,
            array=labels,
            array_spec=gp.ArraySpec(
                interpolatable=False, voxel_size=voxel_size, dtype=np.uint32),
        ) + GrowLabels(labels, radii=[neuron_radius * micron_scale])
        # TODO: Do these need to be scaled by world units?
        + gp.ElasticAugment(
            [40, 10, 10],
            [0.25, 1, 1],
            [0, math.pi / 2.0],
            subsample=4,
            use_fast_points_transform=True,
            recompute_missing_points=False,
        )
        # + gp.SimpleAugment(mirror_only=[1, 2], transpose_only=[1, 2])
        + gp.Normalize(raw) + gp.IntensityAugment(raw, 0.9, 1.1, -0.001, 0.001)
        for sample in samples_path.iterdir()
        if sample.name in ("2018-07-02", "2018-08-01"))

    pipeline = (
        data_sources + gp.RandomProvider() + Crop(labels, labels_fg) +
        BinarizeGt(labels_fg, labels_fg_bin) +
        gp.BalanceLabels(labels_fg_bin, loss_weights) +
        gp.PreCache(cache_size=cache_size, num_workers=num_workers) +
        gp.tensorflow.Train(
            "train_net",
            optimizer=create_custom_loss(mknet_tensor_names, setup_config),
            loss=None,
            inputs={
                mknet_tensor_names["loss_weights"]: loss_weights,
                mknet_tensor_names["raw"]: raw,
                mknet_tensor_names["gt_labels"]: labels_fg,
            },
            outputs={
                mknet_tensor_names["embedding"]: embedding,
                mknet_tensor_names["fg"]: fg,
                loss_tensor_names["fg_pred"]: fg_pred,
                loss_tensor_names["maxima"]: maxima,
                loss_tensor_names["gt_fg"]: gt_fg,
                loss_tensor_names["emst"]: emst,
                loss_tensor_names["edges_u"]: edges_u,
                loss_tensor_names["edges_v"]: edges_v,
                loss_tensor_names["ratio_pos"]: ratio_pos,
                loss_tensor_names["ratio_neg"]: ratio_neg,
                loss_tensor_names["dist"]: dist,
                loss_tensor_names["num_pos_pairs"]: num_pos_pairs,
                loss_tensor_names["num_neg_pairs"]: num_neg_pairs,
            },
            gradients={
                mknet_tensor_names["embedding"]: gradient_embedding,
                mknet_tensor_names["fg"]: gradient_fg,
            },
            save_every=checkpoint_every,
            summary="Merge/MergeSummary:0",
            log_dir="tensorflow_logs",
        ) + gp.PrintProfilingStats(every=profile_every) + gp.Snapshot(
            additional_request=snapshot_request,
            output_filename="snapshot_{}_{}.hdf".format(
                int(np.min(seperate_distance)), "{id}"),
            dataset_names={
                # raw data
                raw: "volumes/raw",
                # labeled data
                labels: "volumes/labels",
                # trees
                skeletonization: "points/skeletonization",
                consensus: "points/consensus",
                matched: "points/matched",
                # output volumes
                embedding: "volumes/embedding",
                fg: "volumes/fg",
                maxima: "volumes/maxima",
                gt_fg: "volumes/gt_fg",
                fg_pred: "volumes/fg_pred",
                gradient_embedding: "volumes/gradient_embedding",
                gradient_fg: "volumes/gradient_fg",
                # output trees
                emst: "emst",
                edges_u: "edges_u",
                edges_v: "edges_v",
                # output debug data
                ratio_pos: "ratio_pos",
                ratio_neg: "ratio_neg",
                dist: "dist",
                num_pos_pairs: "num_pos_pairs",
                num_neg_pairs: "num_neg_pairs",
                loss_weights: "volumes/loss_weights",
            },
            every=snapshot_every,
        ))

    with gp.build(pipeline):
        for _ in range(num_iterations):
            pipeline.request_batch(request)
Example #22
0
def train_until(max_iteration):

    # get the latest checkpoint
    if tf.train.latest_checkpoint("."):
        trained_until = int(tf.train.latest_checkpoint(".").split("_")[-1])
    else:
        trained_until = 0
    if trained_until >= max_iteration:
        return

    # array keys for data sources
    raw = gp.ArrayKey("RAW")
    swcs = gp.PointsKey("SWCS")

    voxel_size = gp.Coordinate((10, 3, 3))
    input_size = gp.Coordinate(net_config["input_shape"]) * voxel_size * 2

    # add request
    request = gp.BatchRequest()
    request.add(raw, input_size)
    request.add(swcs, input_size)

    data_sources = tuple((
        gp.N5Source(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs-carved.n5"
            ).absolute()),
            datasets={raw: "volume"},
            array_specs={
                raw:
                gp.ArraySpec(interpolatable=True,
                             voxel_size=voxel_size,
                             dtype=np.uint16)
            },
        ),
        MouselightSwcFileSource(
            filename=str((
                filename /
                "consensus-neurons-with-machine-centerpoints-labelled-as-swcs/G-002.swc"
            ).absolute()),
            points=(swcs, ),
            scale=voxel_size,
            transpose=(2, 1, 0),
            transform_file=str((filename / "transform.txt").absolute()),
        ),
    ) + gp.MergeProvider() + gp.RandomLocation(ensure_nonempty=swcs,
                                               ensure_centered=True)
                         for filename in Path(sample_dir).iterdir()
                         if "2018-08-01" in filename.name)

    pipeline = data_sources + gp.RandomProvider()

    with gp.build(pipeline):

        print("Starting training...")
        for i in range(max_iteration - trained_until):
            batch = pipeline.request_batch(request)
            vis_points_with_array(batch[raw].data,
                                  points_to_graph(batch[swcs].data),
                                  np.array(voxel_size))