def gather_from_batch(tensor: Tensor, indices: Tensor) -> Tensor:
    """Gather specific indices from a batch of data.

    This method can be useful if you need to compute gradients based on a specific subset of a tensor's output values.
    The `indices` will automatically be cast to the correct type (tf, torch, np) based on the type of the `tensor`.

    This method can be used with Numpy data:
    ```python
    ind = np.array([1, 0, 1])
    n = np.array([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather_from_batch(n, ind)  # [1, 2, 5]
    n = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather_from_batch(n, ind)  # [[2, 3], [4, 5], [10, 11]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    ind = tf.constant([1, 0, 1])
    t = tf.constant([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather_from_batch(t, ind)  # [1, 2, 5]
    t = tf.constant([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather_from_batch(t, ind)  # [[2, 3], [4, 5], [10, 11]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    ind = torch.tensor([1, 0, 1])
    p = torch.tensor([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather_from_batch(p, ind)  # [1, 2, 5]
    p = torch.tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather_from_batch(p, ind)  # [[2, 3], [4, 5], [10, 11]]
    ```

    Args:
        tensor: A tensor of shape (batch, d1, ..., dn).
        indices: A tensor of shape (batch, ) or (batch, 1) indicating which indices should be selected.

    Returns:
        A tensor of shape (batch, d2, ..., dn) containing the elements from `tensor` at the given `indices`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        indices = to_tensor(indices, 'tf')
        indices = tf.cast(indices, tf.int64)
        if len(indices.shape) == 1:  # Indices not batched
            indices = expand_dims(indices, 1)
        return tf.gather_nd(tensor, indices=indices, batch_dims=1)
    elif isinstance(tensor, torch.Tensor):
        # Cast the indices to int64 so they are valid for tensor indexing
        return tensor[torch.arange(tensor.shape[0]), squeeze(indices).type(torch.int64)]
    elif isinstance(tensor, np.ndarray):
        return tensor[np.arange(tensor.shape[0]), squeeze(indices).astype('int64')]
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
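As a concrete illustration of the gradient use case mentioned in the docstring, here is a minimal sketch (the `y_pred`/`y_true` names and values are illustrative, not part of the API):
```python
import numpy as np
import fastestimator as fe

# Hypothetical per-sample class probabilities (batch of 3, 2 classes)
y_pred = np.array([[0.9, 0.1], [0.3, 0.7], [0.2, 0.8]])
# The class index to extract for each sample
y_true = np.array([0, 1, 1])
# Each sample's probability for its own target class: [0.9, 0.7, 0.8]
true_class_probs = fe.backend.gather_from_batch(y_pred, y_true)
```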
Example #2
def gather(tensor: Tensor, indices: Tensor) -> Tensor:
    """Gather specific indices from a tensor.

    The `indices` will automatically be cast to the correct type (tf, torch, np) based on the type of the `tensor`.

    This method can be used with Numpy data:
    ```python
    ind = np.array([1, 0, 1])
    n = np.array([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather(n, ind)  # [[2, 3], [0, 1], [2, 3]]
    n = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather(n, ind)  # [[[4, 5], [6, 7]], [[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    ```

    This method can be used with TensorFlow tensors:
    ```python
    ind = tf.constant([1, 0, 1])
    t = tf.constant([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather(t, ind)  # [[2, 3], [0, 1], [2, 3]]
    t = tf.constant([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather(t, ind)  # [[[4, 5], [6, 7]], [[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    ```

    This method can be used with PyTorch tensors:
    ```python
    ind = torch.tensor([1, 0, 1])
    p = torch.tensor([[0, 1], [2, 3], [4, 5]])
    b = fe.backend.gather(p, ind)  # [[2, 3], [0, 1], [2, 3]]
    p = torch.tensor([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]])
    b = fe.backend.gather(p, ind)  # [[[4, 5], [6, 7]], [[0, 1], [2, 3]], [[4, 5], [6, 7]]]
    ```

    Args:
        tensor: A tensor to gather values from.
        indices: A tensor indicating which indices should be selected. These represent locations along axis 0.

    Returns:
        A tensor containing the elements from `tensor` at the given `indices`.

    Raises:
        ValueError: If `tensor` is an unacceptable data type.
    """
    if tf.is_tensor(tensor):
        indices = to_tensor(indices, 'tf')
        indices = tf.cast(indices, tf.int64)
        return tf.gather(tensor, indices=squeeze(indices), axis=0)
    elif isinstance(tensor, torch.Tensor):
        return tensor[squeeze(indices).type(torch.int64)]
    elif isinstance(tensor, np.ndarray):
        return np.take(tensor, squeeze(indices).astype('int64'), axis=0)
    else:
        raise ValueError("Unrecognized tensor type {}".format(type(tensor)))
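Since `gather` and `gather_from_batch` are easy to confuse, a minimal sketch of the difference (values taken from the docstring examples above): `gather` selects whole entries along axis 0, so it can repeat or reorder batch elements, while `gather_from_batch` selects one element per batch entry.
```python
import numpy as np
import fastestimator as fe

n = np.array([[0, 1], [2, 3], [4, 5]])
ind = np.array([1, 0, 1])
fe.backend.gather(n, ind)             # [[2, 3], [0, 1], [2, 3]]  (rows 1, 0, 1)
fe.backend.gather_from_batch(n, ind)  # [1, 2, 5]  (element ind[i] of row i)
```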
Example #3
    def _weight_to_image(
            weight: Tensor,
            kernel_channels_last: bool = False) -> Optional[Tensor]:
        """Logs a weight as a TensorBoard image.

        Implementation adapted from the TensorFlow codebase; it would have been invoked directly, but they
        didn't make it a static method.
        """
        w_img = squeeze(weight)
        shape = backend.int_shape(w_img)
        if len(shape) == 1:  # Bias case
            w_img = reshape(w_img, [1, shape[0], 1, 1])
        elif len(shape) == 2:  # Dense layer kernel case
            if shape[0] > shape[1]:
                # Transpose so that the longer dimension is displayed vertically
                w_img = permute(w_img, [1, 0])
                shape = backend.int_shape(w_img)
            w_img = reshape(w_img, [1, shape[0], shape[1], 1])
        elif len(shape) == 3:  # ConvNet case
            if kernel_channels_last:
                # Switch to channels-first to display every kernel as a separate image
                w_img = permute(w_img, [2, 0, 1])
            w_img = expand_dims(w_img, axis=-1)
        elif len(shape) == 4:  # Conv filter with multiple input channels
            if kernel_channels_last:
                # Switch to channels-first to display kernels as separate images
                w_img = permute(w_img, [3, 2, 0, 1])
            w_img = reduce_sum(abs(w_img), axis=1)  # Sum over each input channel within the kernel
            w_img = expand_dims(w_img, axis=-1)
        shape = backend.int_shape(w_img)
        # Not possible to handle 3D convnets etc.
        if len(shape) == 4 and shape[-1] in [1, 3, 4]:
            return w_img
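To make the branching above concrete, a minimal sketch tracing how each weight rank maps onto a displayable 4D image tensor, assuming `kernel_channels_last=True` (the variable names and shapes are illustrative):
```python
import tensorflow as tf

bias = tf.zeros([64])              # rank 1 (bias)         -> (1, 64, 1, 1)
dense = tf.zeros([128, 10])        # rank 2 (dense kernel) -> transposed since 128 > 10, then (1, 10, 128, 1)
conv1d = tf.zeros([3, 16, 32])     # rank 3 (Conv1D)       -> (32, 3, 16, 1): one grayscale image per kernel
conv2d = tf.zeros([3, 3, 16, 32])  # rank 4 (Conv2D)       -> (32, 3, 3, 1): |weights| summed over 16 input channels
```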
Example #4
    def on_epoch_end(self, data: Data) -> None:
        mode = self.system.mode
        if self.n_found[mode] > 0:
            if self.n_required[mode] > 0:
                # We are keeping a user-specified number of samples
                self.samples[mode] = {
                    key: concat(val)[:self.n_required[mode]]
                    for key, val in self.samples[mode].items()
                }
            else:
                # We are keeping one batch of data
                self.samples[mode] = {
                    key: val[0]
                    for key, val in self.samples[mode].items()
                }
            # Even if fewer than n_required samples were found, the epoch is over, so there's no point collecting more
            self.n_found[mode] = 0
            self.n_required[mode] = 0

        masks = self.salnet.get_masks(self.samples[mode])
        smoothed, integrated, smint = {}, {}, {}
        if self.smoothing:
            smoothed = self.salnet.get_smoothed_masks(self.samples[mode],
                                                      nsamples=self.smoothing)
        if self.integrating:
            if isinstance(self.integrating, tuple):
                n_integration, n_smoothing = self.integrating
            else:
                n_integration = self.integrating
                n_smoothing = self.smoothing
            integrated = self.salnet.get_integrated_masks(
                self.samples[mode], nsamples=n_integration)
            if n_smoothing:
                smint = self.salnet.get_smoothed_masks(
                    self.samples[mode],
                    nsamples=n_smoothing,
                    nintegration=n_integration)

        # Arrange the outputs
        args = {}
        if self.class_key:
            classes = self.samples[mode][self.class_key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[self.class_key] = classes
        for key in self.model_outputs:
            classes = masks[key]
            if self.label_mapping:
                classes = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(classes))
                ])
            args[key] = classes
        sal = smint or integrated or smoothed or masks
        for key, val in self.samples[mode].items():
            if key != self.class_key:
                args[key] = val
                # Create a linear combination of the original image, the saliency mask, and the product of the two in
                # order to highlight regions of importance
                min_val = reduce_min(val)
                diff = reduce_max(val) - min_val
                for outkey in self.outputs:
                    args["{} {}".format(key, outkey)] = (
                        0.3 * (sal[outkey] * (val - min_val) + min_val)
                        + 0.3 * val
                        + 0.4 * sal[outkey] * diff
                        + min_val)
        for key in self.outputs:
            args[key] = masks[key]
            if smoothed:
                args["Smoothed {}".format(key)] = smoothed[key]
            if integrated:
                args["Integrated {}".format(key)] = integrated[key]
            if smint:
                args["SmInt {}".format(key)] = smint[key]
        result = ImgData(colormap="inferno", **args)

        data.write_without_log(self.outputs[0], result)
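The weighted blend above is the core of the visualization; here is a standalone numpy sketch of the same 0.3/0.3/0.4 weighting (the function name is illustrative, not part of the trace):
```python
import numpy as np

def blend_saliency(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend an image with its saliency mask to highlight regions of importance."""
    min_val = image.min()
    diff = image.max() - min_val
    # 30% saliency-masked image + 30% original image + 40% mask scaled to the image's value range
    return (0.3 * (mask * (image - min_val) + min_val)
            + 0.3 * image
            + 0.4 * mask * diff
            + min_val)
```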
Example #5
    def on_epoch_end(self, data: Data) -> None:
        # Keep only the user-specified number of samples
        images = concat(self.images)[:self.n_samples or self.n_found]
        _, height, width = get_image_dims(images)
        grads = to_number(concat(self.grads)[:self.n_samples or self.n_found])
        if tf.is_tensor(images):
            grads = np.moveaxis(grads, source=-1,
                                destination=1)  # grads should be channel first
        args = {}
        labels = None if not self.labels else concat(
            self.labels)[:self.n_samples or self.n_found]
        if labels is not None:
            if len(labels.shape) > 1:
                labels = argmax(labels, axis=-1)
            if self.label_mapping:
                labels = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(labels))
                ])
            args[self.true_label_key] = labels
        preds = None if not self.preds else concat(
            self.preds)[:self.n_samples or self.n_found]
        if preds is not None:
            if len(preds.shape) > 1:
                preds = argmax(preds, axis=-1)
            if self.label_mapping:
                preds = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(preds))
                ])
            args[self.pred_label_key] = preds
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        # TODO: In future maybe allow multiple different grads to have side-by-side comparisons of classes
        components = [np.mean(grads, axis=1)]
        components = [np.maximum(component, 0) for component in components]
        masks = []
        for component_batch in components:
            img_batch = []
            for img in component_batch:
                img = cv2.resize(img, (width, height))
                img = img - np.min(img)
                img = img / np.max(img)
                img = cv2.cvtColor(
                    cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET),
                    cv2.COLOR_BGR2RGB)
                img = np.float32(img) / 255
                img_batch.append(img)
            img_batch = np.array(img_batch, dtype=np.float32)
            # Switch to channel first for pytorch
            if isinstance(images, torch.Tensor):
                img_batch = np.moveaxis(img_batch, source=-1, destination=1)
            masks.append(img_batch)

        components = [
            images + mask for mask in masks
        ]  # This seems to work even if the image is 1 channel instead of 3
        components = [image / reduce_max(image) for image in components]

        # components currently holds a single element (see the TODO above), so this keeps only that one
        for elem in components:
            args[self.grad_key] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
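The inner loop above follows a common Grad-CAM-style recipe: resize the activation map to the image size, min-max normalize it, and colorize it with a JET colormap. As an isolated sketch (the function name is illustrative):
```python
import cv2
import numpy as np

def to_heatmap(component: np.ndarray, width: int, height: int) -> np.ndarray:
    """Resize, min-max normalize, and colorize a 2D map into an RGB float image in [0, 1]."""
    img = cv2.resize(component, (width, height))
    img = img - np.min(img)
    img = img / np.max(img)  # Assumes a non-constant map (max > 0 after the shift)
    img = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    return np.float32(img) / 255
```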
Example #6
    def on_epoch_end(self, data: Data) -> None:
        # Keep only the user-specified number of samples
        images = concat(self.images)[:self.n_samples or self.n_found]
        _, height, width = get_image_dims(images)
        activations = to_number(
            concat(self.activations)[:self.n_samples or self.n_found])
        if tf.is_tensor(images):
            activations = np.moveaxis(
                activations, source=-1,
                destination=1)  # Activations should be channel first
        args = {}
        labels = None if not self.labels else concat(
            self.labels)[:self.n_samples or self.n_found]
        if labels is not None:
            if len(labels.shape) > 1:
                labels = argmax(labels, axis=-1)
            if self.label_mapping:
                labels = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(labels))
                ])
            args[self.true_label_key] = labels
        preds = None if not self.preds else concat(
            self.preds)[:self.n_samples or self.n_found]
        if preds is not None:
            if len(preds.shape) > 1:
                preds = argmax(preds, axis=-1)
            if self.label_mapping:
                preds = np.array([
                    self.label_mapping[clazz]
                    for clazz in to_number(squeeze(preds))
                ])
            args[self.pred_label_key] = preds
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        n_components, batch_component_image = self._project_2d(activations)
        components = []  # component x image (batch x image)
        for component_idx in range(n_components):
            batch = []
            for base_image, component_image in zip(images,
                                                   batch_component_image):
                if len(component_image) > component_idx:
                    mask = component_image[component_idx]
                    mask = cv2.resize(mask, (width, height))
                    mask = mask - np.min(mask)
                    mask = mask / np.max(mask)
                    mask = cv2.cvtColor(
                        cv2.applyColorMap(np.uint8(255 * mask),
                                          cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                    mask = np.float32(mask) / 255
                    # Switch to channel first for pytorch
                    if isinstance(base_image, torch.Tensor):
                        mask = np.moveaxis(mask, source=-1, destination=1)
                    new_image = base_image + mask
                    new_image = new_image / reduce_max(new_image)
                else:
                    # There's no component for this image, so display an empty image here
                    new_image = np.ones_like(base_image)
                batch.append(new_image)
            components.append(np.array(batch, dtype=np.float32))

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)