Beispiel #1
0
    def compute(self, primary_features: Tensor,
                secondary_features: Sequence[Tensor],
                input_image_size: Size2D):
        """
        Refine the primary features using each secondary feature map in turn.

        Every secondary feature map is first passed through its matching
        entry in self.secondary_projections. The running result (seeded with
        primary_features) and each projection are then bilinearly resampled
        to the decoded size, concatenated along the channel axis and passed
        through self.refiner. The output of the last refinement is returned;
        with no secondary features, primary_features is returned untouched.
        """
        # Bring every secondary feature map to a common number of channels.
        # One projection layer is expected per secondary feature map.
        projection_layers = self.secondary_projections
        assert len(projection_layers) == len(secondary_features)
        projected = []
        for layer, features in zip(projection_layers, secondary_features):
            projected.append(layer(features))

        # The decoded size is the input size scaled down by the output
        # stride, quantized with ceil.
        stride_factor = 1 / self.config.output_stride
        target_size = input_image_size.scale(factor=stride_factor,
                                             quantize=tf.math.ceil)

        def to_target_size(feature_map):
            # Bilinear resampling to the shared decoded size
            return resample(feature_map,
                            size=target_size,
                            method=resample.BILINEAR,
                            resampler=legacy_aligned_resampler)

        # Fold each projection into the running result
        refined = primary_features
        for secondary in projected:
            same_sized = [to_target_size(refined), to_target_size(secondary)]
            stacked = tf.concat(same_sized, axis=Axis.channel)
            refined = self.refiner(stacked)

        return refined
Beispiel #2
0
    def segment(self, image: Tensor, match_size=True) -> Tensor:
        """
        Semantically segment the given image, returning a label map whose
        values are class indices.

        The image is pre-processed and passed through the predictor; the
        resulting logits are reduced with an argmax over the channel axis.

        If match_size is True, the label map is resampled (nearest neighbor)
        to match the input image and then squeezed. Otherwise, it is returned
        at the subsampled size produced by the predictor.
        """
        # Per-class logits for the pre-processed input
        network_input = self.pre_process(image)
        class_logits = self.predictor(network_input)

        # Most likely class at every pixel
        label_map = tf.argmax(class_logits, axis=Axis.channel)

        if not match_size:
            return label_map

        # Resample with a temporary trailing axis so the label map lines up
        # with the input image.
        resized = resample(label_map[..., tf.newaxis],
                           like=image,
                           method=resample.NEAREST_NEIGHBOR)

        # NOTE(review): squeezing without an explicit axis removes every
        # size-1 dimension, so a batch of one also loses its batch axis here
        # (but not in the match_size=False path) - confirm this is intended.
        return tf.squeeze(resized)
Beispiel #3
0
    def compute(self, inputs: Tensor):
        """
        Combine image-level pooled features with the parallel branch outputs.

        The inputs are averaged over the height and width axes, projected,
        and resampled back to the size of the inputs. That result is
        concatenated (channel axis) with the output of every branch in
        self.branches, and the concatenation is passed through the output
        projection.
        """
        # "Image pooling" / "image level features" in the reference
        # implementation: a global average over the spatial axes.
        spatial_axes = (Axis.height, Axis.width)
        image_level = tf.reduce_mean(inputs, axis=spatial_axes, keepdims=True)

        # Project the (1x1) pooled output, then bring it back to input size
        image_level = self.pooling_projection(image_level)
        image_level = resample(tensor=image_level,
                               like=inputs,
                               method=resample.NEAREST_NEIGHBOR,
                               resampler=legacy_aligned_resampler)

        # Run every parallel branch on the same inputs
        parallel_outputs = []
        for branch in self.branches:
            parallel_outputs.append(branch(inputs))

        # Pooled features go first, followed by the branches in order
        merged = tf.concat([image_level, *parallel_outputs],
                           axis=Axis.channel)

        return self.output_projection(merged)
Beispiel #4
0
    def pre_process(self, image: Tensor) -> Tensor:
        """
        Prepare an input image for the network.

        The image is cast to float32 and rescaled so that values in
        [0, 255] map to [-1, 1]. A rank 3 image gets a leading batch
        dimension. If self.config.input_size is set, the result is
        bilinearly resampled to that size.

        Raises a ValueError if the image is neither rank 3 nor rank 4.
        """
        # Float conversion followed by [0, 255] -> [-1, 1] normalization
        as_float = tf.cast(image, tf.float32)
        normalized = (as_float * 2. / 255.) - 1.0

        # Only rank 3 (single image) and rank 4 (batch) are accepted
        rank = normalized.shape.rank
        if rank not in (3, 4):
            raise ValueError('Input image must be either rank 3 or 4.')
        if rank == 3:
            # Inject the batch dimension
            normalized = normalized[tf.newaxis, ...]

        # Without a configured input size the image keeps its own size
        target_size = self.config.input_size
        if not target_size:
            return normalized

        return resample(normalized,
                        size=target_size,
                        method=resample.BILINEAR)
Beispiel #5
0
def preprocess(image, size=(256, 256)):
    """
    Convert an image into a normalized, batched and resized float tensor.

    The image is cast to float32, rescaled so that values in [0, 255] map
    to [-1, 1], given a leading batch dimension, and bilinearly resampled
    to the requested size.
    """
    # Float conversion followed by [0, 255] -> [-1, 1] normalization
    normalized = (tf.cast(image, tf.float32) * 2. / 255.) - 1.0

    # A leading batch axis is always added
    batched = normalized[tf.newaxis, ...]

    return resample(batched, size=size, method=resample.BILINEAR)