Ejemplo n.º 1
0
 def create_target(self, gt):
     """Build the float32 training target: affinities, optionally followed by descriptors.

     The ground-truth labels are turned into an affinity graph with
     ``seg_to_affgraph`` using ``self.neighborhood``. When ``self.lsds`` is
     truthy, the descriptors returned by
     ``self.extractor(gt.voxel_size).get_descriptors(...)`` (presumably local
     shape descriptors -- confirm) are appended after the affinity channels
     along axis 0.

     Args:
         gt: Ground-truth label array with either no channel axis or exactly
             one channel.

     Returns:
         A ``NumpyArray`` on ``gt.roi`` / ``gt.voxel_size``. Its axes are
         ``gt.axes`` when ``gt`` already has a channel axis, otherwise
         ``["c"] + gt.axes``.

     Raises:
         AssertionError: if ``gt`` has more than one channel.
     """
     assert gt.num_channels is None or gt.num_channels == 1, (
         "Cannot create affinities from ground truth with multiple channels.\n"
         f"GT axes: {gt.axes} with {gt.num_channels} channels")
     labels = gt[gt.roi]
     if gt.num_channels is None:
         # No channel axis in gt, but the output stacks channels first.
         target_axes = ["c"] + gt.axes
     else:
         # Drop the singleton channel axis before computing affinities.
         labels = labels[0]
         target_axes = gt.axes

     target_data = seg_to_affgraph(labels, self.neighborhood).astype(np.float32)
     if self.lsds:
         extractor = self.extractor(gt.voxel_size)
         descriptors = extractor.get_descriptors(
             segmentation=labels,
             voxel_size=gt.voxel_size,
         )
         target_data = np.concatenate(
             [target_data, descriptors], axis=0, dtype=np.float32
         )

     return NumpyArray.from_np_array(
         target_data, gt.roi, gt.voxel_size, target_axes
     )
Ejemplo n.º 2
0
 def next(self):
     """Fetch the next batch and convert its arrays to ``NumpyArray`` objects.

     Returns:
         A 5-tuple ``(raw, gt, target, weight, mask)``. ``mask`` is ``None``
         when ``self._mask_key`` is ``None``.
     """
     batch = next(self._iter)
     # Resume the producing generator after taking a batch. Presumably the
     # ``False`` tells it not to stop -- confirm against the generator.
     self._iter.send(False)
     required_keys = (
         self._raw_key,
         self._gt_key,
         self._target_key,
         self._weight_key,
     )
     converted = [NumpyArray.from_gp_array(batch[key]) for key in required_keys]
     if self._mask_key is None:
         mask_array = None
     else:
         mask_array = NumpyArray.from_gp_array(batch[self._mask_key])
     return (*converted, mask_array)
Ejemplo n.º 3
0
 def create_weight(self, gt, target, mask):
     """Return uniform weights of ones with the same shape as ``target.data``.

     ``gt`` and ``mask`` are not used here. The weights use ``np.ones``'s
     default dtype (float64) and share ``target``'s ROI, voxel size and axes.
     """
     uniform_weights = np.ones(target.data.shape)
     return NumpyArray.from_np_array(
         uniform_weights, target.roi, target.voxel_size, target.axes
     )
Ejemplo n.º 4
0
 def create_target(self, gt):
     """Return an all-zero target with ``self.embedding_dims`` leading channels.

     The remaining dimensions are the trailing ``gt.dims`` entries of
     ``gt.data.shape``. A ``"c"`` axis is prepended to ``gt.axes``, and the
     ROI and voxel size come from ``gt``.
     """
     trailing_shape = gt.data.shape[-gt.dims :]
     zeros = np.zeros((self.embedding_dims,) + trailing_shape)
     return NumpyArray.from_np_array(
         zeros, gt.roi, gt.voxel_size, ["c"] + gt.axes
     )
Ejemplo n.º 5
0
 def create_target(self, gt):
     """Return ``self.process(gt.data)`` wrapped with ``gt``'s ROI, voxel size and axes.

     Presumably ``self.process`` produces a one-hot encoding (confirm). Note
     that the axes are taken unchanged from ``gt``.
     """
     encoded = self.process(gt.data)
     return NumpyArray.from_np_array(encoded, gt.roi, gt.voxel_size, gt.axes)
Ejemplo n.º 6
0
 def create_target(self, gt):
     """Compute the distance target via ``self.process`` and wrap it with ``gt``'s geometry.

     ``self.process`` receives the ground-truth data, its voxel size,
     ``self.norm`` and ``self.dt_scale_factor``, positionally and in that
     order.
     """
     process_args = (gt.data, gt.voxel_size, self.norm, self.dt_scale_factor)
     distance_data = self.process(*process_args)
     return NumpyArray.from_np_array(
         distance_data, gt.roi, gt.voxel_size, gt.axes
     )
Ejemplo n.º 7
0
    def process(self, batch, request):
        """Build the target and weight arrays for ``request`` from ``batch``.

        The ground truth and mask are read from ``batch``. The target and
        weights are produced by ``self.predictor``, and each is cropped to the
        ROI requested for its key.

        Note: the specs stored in ``request`` for ``self.target_key`` and
        ``self.weights_key`` are modified in place; their ``voxel_size`` is
        set to the ground truth's.

        Returns:
            A new ``gp.Batch`` containing only the target and weight arrays.
        """
        result = gp.Batch()

        gt = NumpyArray.from_gp_array(batch[self.gt_key])
        target = self.predictor.create_target(gt)
        mask = NumpyArray.from_gp_array(batch[self.mask_key])
        weights = self.predictor.create_weight(gt, target, mask=mask)

        for out_key, array in (
            (self.target_key, target),
            (self.weights_key, weights),
        ):
            spec = request[out_key]
            spec.voxel_size = gt.voxel_size
            result[out_key] = gp.Array(array[spec.roi], spec)
        return result
Ejemplo n.º 8
0
 def create_weight(self, gt, target, mask):
     """Compute loss weights for the affinity channels and, optionally, the descriptor channels.

     The first ``self.num_channels - self.num_lsds`` target channels are cast
     to uint8 and passed to ``balance_weights``. The constant ``2`` is
     presumably the number of classes (confirm). The slab is 1 along ``"c"``
     and -1 along every other axis, and ``mask`` is applied. When
     ``self.lsds`` is truthy, ``self.num_lsds`` channels of ones are appended
     after the balanced weights.

     Returns:
         A ``NumpyArray`` sharing ``target``'s ROI, voxel size and axes.
     """
     affinity_part = target[target.roi][: self.num_channels - self.num_lsds]
     slab_shape = tuple(1 if axis == "c" else -1 for axis in target.axes)
     weights = balance_weights(
         affinity_part.astype(np.uint8),
         2,
         slab=slab_shape,
         masks=[mask[target.roi]],
     )
     if self.lsds:
         # Descriptor channels are weighted uniformly.
         lsd_shape = (self.num_lsds,) + weights.shape[1:]
         weights = np.concatenate(
             [weights, np.ones(lsd_shape, dtype=weights.dtype)], axis=0
         )
     return NumpyArray.from_np_array(
         weights, target.roi, target.voxel_size, target.axes
     )
Ejemplo n.º 9
0
 def create_weight(self, gt, target, mask):
     """Compute class-balanced loss weights from ``gt``, optionally masked by distance.

     Weights are balanced over ``gt`` cropped to ``target.roi``, with 2
     passed to ``balance_weights`` (presumably the class count -- confirm).
     The slab is 1 along ``"c"`` and -1 elsewhere; per the original intent,
     this balances each channel independently. Two masks are applied: ``mask``
     and a distance mask. The distance mask comes from
     ``self.create_distance_mask`` when ``self.mask_distances`` is truthy,
     and is all ones otherwise.

     Returns:
         A ``NumpyArray`` using ``gt``'s ROI, voxel size and axes.
     """
     if not self.mask_distances:
         distance_mask = np.ones_like(target.data)
     else:
         distance_mask = self.create_distance_mask(
             target[target.roi],
             mask[target.roi],
             target.voxel_size,
             self.norm,
             self.dt_scale_factor,
         )
     per_channel_slab = tuple(1 if axis == "c" else -1 for axis in gt.axes)
     balanced = balance_weights(
         gt[target.roi],
         2,
         slab=per_channel_slab,
         masks=[mask[target.roi], distance_mask],
     )
     return NumpyArray.from_np_array(balanced, gt.roi, gt.voxel_size, gt.axes)
Ejemplo n.º 10
0
    def iterate(self, num_iterations, model, optimizer, device):
        """Run ``num_iterations`` training steps, yielding stats after each one.

        Each step fetches a batch with ``self.next()``, clears gradients by
        setting them to ``None``, and runs the model on ``raw``. It then
        computes ``self._loss`` against ``target`` and ``weight``,
        backpropagates, and steps the optimizer. When
        ``self.snapshot_iteration`` is set and divides the iteration number,
        the inputs, prediction and prediction gradients are written to the
        snapshot zarr container under ``"<iteration>/volumes/..."``.

        Yields:
            ``TrainingIterationStats`` with the loss, the iteration number, and
            the time since the forward pass started (snapshot writing
            included).
        """
        fetch_started = time.time()

        logger.info("Starting iteration!")

        def as_device_tensor(array):
            # Read the whole ROI and move it to ``device`` as a float tensor.
            return torch.as_tensor(array[array.roi]).to(device).float()

        first_iteration = self.iteration
        for iteration in range(first_iteration, first_iteration + num_iterations):
            raw, gt, target, weight, mask = self.next()
            logger.debug(
                f"Trainer fetch batch took {time.time() - fetch_started} seconds"
            )

            for parameter in model.parameters():
                parameter.grad = None

            step_started = time.time()
            predicted = model.forward(as_device_tensor(raw))
            # ``predicted`` is not a leaf tensor; keep its gradient so it can
            # be written to the snapshot.
            predicted.retain_grad()
            loss = self._loss.compute(
                predicted, as_device_tensor(target), as_device_tensor(weight)
            )
            loss.backward()
            optimizer.step()

            take_snapshot = (
                self.snapshot_iteration is not None
                and iteration % self.snapshot_iteration == 0
            )
            if take_snapshot:
                snapshot_zarr = zarr.open(self.snapshot_container.container, "a")
                to_save = {
                    "volumes/raw": raw,
                    "volumes/gt": gt,
                    "volumes/target": target,
                    "volumes/weight": weight,
                }
                # Prediction and its gradient reuse the target's geometry.
                for volume_name, tensor in (
                    ("volumes/prediction", predicted),
                    ("volumes/gradients", predicted.grad),
                ):
                    to_save[volume_name] = NumpyArray.from_np_array(
                        tensor.detach().cpu().numpy(),
                        target.roi,
                        target.voxel_size,
                        target.axes,
                    )
                if mask is not None:
                    to_save["volumes/mask"] = mask
                logger.warning(
                    f"Saving Snapshot. Iteration: {iteration}, "
                    f"Loss: {loss.detach().cpu().numpy().item()}!"
                )
                for volume_name, array in to_save.items():
                    path = f"{iteration}/{volume_name}"
                    is_bool = array.dtype == bool
                    if path not in snapshot_zarr:
                        # Boolean volumes are stored as float32.
                        ZarrArray.create_from_array_identifier(
                            self.snapshot_container.array_identifier(path),
                            array.axes,
                            array.roi,
                            array.num_channels,
                            array.voxel_size,
                            np.float32 if is_bool else array.dtype,
                        )
                    dataset = snapshot_zarr[path]
                    # Drop the leading batch axis. Per the original author,
                    # everything has batch and channel axes because of torch.
                    data = array[array.roi][0]
                    if is_bool:
                        data = data.astype(np.float32)
                    if array.num_channels is None:
                        # No channels expected: drop the singleton channel axis.
                        assert data.shape[0] == 1, (
                            f"Data for array {path} should not have channels but has shape: "
                            f"{array.shape}. The first dimension is channels"
                        )
                        data = data[0]
                    dataset[:] = data
                    dataset.attrs["offset"] = array.roi.offset
                    dataset.attrs["resolution"] = array.voxel_size
                    dataset.attrs["axes"] = array.axes

            logger.debug(
                f"Trainer step took {time.time() - step_started} seconds"
            )
            self.iteration += 1
            yield TrainingIterationStats(
                loss=loss.item(),
                iteration=iteration,
                time=time.time() - step_started,
            )
            fetch_started = time.time()