Ejemplo n.º 1
0
    def create_input_tensors_dict(self, batch_size):
        """Assemble the input tensor dict for the configured batching mode.

        ``self.batching_mode`` selects which producer is called:

        * ``"pair"``  -> ``self._create_inputs_for_pair``; the raw images are
          ``unnormalize(imgs)``.
        * ``"group"`` -> ``self._create_inputs_for_group``; its labels are reused
          as the original labels and its fourth output becomes the raw images.
        * ``"eval"``  -> ``self._create_inputs_for_eval``; its labels are reused
          as the original labels.

        Args:
            batch_size: forwarded unchanged to the selected producer.

        Returns:
            dict with keys "inputs", "labels", "tags", "original_labels" and,
            only when the raw images are not None, "imgs_raw".

        Raises:
            ValueError: if ``self.batching_mode`` is not one of the modes above.
        """
        mode = self.batching_mode
        if mode == "pair":
            imgs, labels, tags, original_classes = self._create_inputs_for_pair(batch_size)
            imgs_raw = unnormalize(imgs)
        elif mode == "group":
            imgs, labels, tags, imgs_raw = self._create_inputs_for_group(batch_size)
            original_classes = labels
        elif mode == "eval":
            imgs, imgs_raw, tags, labels = self._create_inputs_for_eval(batch_size)
            original_classes = labels
        else:
            raise ValueError("Incorrect batching mode error")

        if self.use_summaries:
            self.summaries.append(tf.summary.image("imgs", unnormalize(imgs)))

        result = dict(inputs=imgs, labels=labels, tags=tags,
                      original_labels=original_classes)
        # Raw images are optional; only expose them when a producer supplied them.
        if imgs_raw is not None:
            result["imgs_raw"] = imgs_raw
        return result
Ejemplo n.º 2
0
    def create_input_tensors_dict(self, batch_size):
        """Build the batched input tensor dict from the input file lists.

        The input file lists are loaded, fed through a
        ``tf.train.slice_input_producer`` queue, read/augmented by
        ``self._read_inputfiles`` and batched with ``create_batch_dict``.

        When ``self.use_summaries`` is set, image summaries for the inputs, the
        labels and any extra input channels are built. These replace the
        summaries returned by ``_read_inputfiles``. The resulting summaries are
        appended to ``self.summaries``.

        Args:
            batch_size: batch size passed to ``create_batch_dict``.

        Returns:
            The batched tensor dict.
        """
        self._load_inputfile_lists()
        resize_mode, input_size = self._get_resize_params(self.subset, self.image_size, ResizeMode.Unchanged)
        augmentors, shuffle = self._parse_augmentors_and_shuffle()

        inputfile_tensors = [tf.convert_to_tensor(l, dtype=tf.string) for l in self.inputfile_lists]
        queue = tf.train.slice_input_producer(inputfile_tensors, shuffle=shuffle)

        tensors_dict, summaries = self._read_inputfiles(queue, resize_mode, input_size, augmentors)
        tensors_dict = create_batch_dict(batch_size, tensors_dict)

        if self.use_summaries:
            inputs = tensors_dict["inputs"]
            num_channels = inputs.get_shape()[-1]
            # The old label is present either as a single extra channel (4 channels
            # in total) or together with 2 distance-transform channels (6 in total).
            # (Previously this compared the bound method `get_shape` to 6, which is
            # always False, so the 6-channel case was never detected.)
            has_old_label = num_channels == 4 or num_channels == 6
            # First click / distance-transform channel: right after RGB, or right
            # after the old label when one is present.
            dt_start = 4 if has_old_label else 3

            input_img = unnormalize(inputs[:, :, :, :3])
            # Add clicks to the image so that they can be viewed in tensorboard
            if num_channels > 4:
                [input_img] = tf.py_func(
                    self.add_clicks,
                    [tf.concat([input_img, inputs[:, :, :, dt_start:dt_start + 1]], axis=3), 'r'],
                    [tf.float32])
                [input_img] = tf.py_func(
                    self.add_clicks,
                    [tf.concat([input_img, inputs[:, :, :, dt_start + 1:dt_start + 2]], axis=3), 'g'],
                    [tf.float32])

            summ0 = tf.summary.image("inputs", input_img)
            summ1 = tf.summary.image("labels", tensors_dict["labels"] * 255)  # will only work for binary segmentation
            summaries = [summ0, summ1]

            if has_old_label:
                summaries.append(tf.summary.image("old_labels", inputs[:, :, :, 3:4] * 255))

            # Append the distance transforms, if they are available.
            if num_channels > 4:
                # Negative distance transform first, then the positive one.
                summ3 = tf.summary.image(Constants.DT_NEG, inputs[:, :, :, dt_start:dt_start + 1])
                summ4 = tf.summary.image(Constants.DT_POS, inputs[:, :, :, dt_start + 1:dt_start + 2])
                summaries.append(summ3)
                summaries.append(summ4)

        self.summaries += summaries
        return tensors_dict
Ejemplo n.º 3
0
    def create_input_tensors_dict(self, batch_size):
        """Build the batched input tensor dict from the feed placeholders.

        The placeholders are wrapped with ``create_tensor_dict``, resized
        according to ``self._get_resize_params`` and batched/augmented via
        ``self._prepare_augmented_batch``. If ``self.use_summaries`` is set,
        image summaries are appended to ``self.summaries``.

        Args:
            batch_size: batch size passed to ``_prepare_augmented_batch``.

        Returns:
            The batched tensor dict.
        """
        # use_index_img is enabled for every subset except "train".
        use_index_img = self.subset != "train"
        tensors = create_tensor_dict(
            unnormalized_img=self.img_placeholder,
            label=self.label_placeholder,
            tag=self.tag_placeholder,
            raw_label=self.label_placeholder,
            old_label=self.old_label_placeholder,
            flow_past=self.flow_into_past_placeholder,
            flow_future=self.flow_into_future_placeholder,
            use_index_img=use_index_img,
            u0=self.u0_placeholder,
            u1=self.u1_placeholder)

        # TODO: need to set shape here?

        resize_mode, input_size = self._get_resize_params(
            self.subset, self.image_size)
        tensors = resize(tensors, resize_mode, input_size)
        # For 3-element sizes only the trailing two entries are passed on as the
        # batch image size (presumably the spatial dims — TODO confirm).
        if len(input_size) == 3:
            input_size = input_size[1:]
        tensors = self._prepare_augmented_batch(tensors,
                                                batch_size,
                                                image_size=input_size)

        if self.use_summaries:
            inputs = tensors["inputs"]
            summ0 = tf.summary.image("inputs",
                                     unnormalize(inputs[:, :, :, :3]))
            # will only work well for binary segmentation
            summ1 = tf.summary.image("labels", tensors["labels"] * 255)
            self.summaries.append(summ0)
            self.summaries.append(summ1)

            # Old label would be present either as a single channel, or along with
            # 2 other distance transform channels. Compare the channel dimension,
            # not the bound `get_shape` method (which never equals 6).
            num_channels = inputs.get_shape()[-1]
            if num_channels == 4 or num_channels == 6:
                summ2 = tf.summary.image("old_labels",
                                         inputs[:, :, :, 3:4] * 255)
                self.summaries.append(summ2)

        return tensors
Ejemplo n.º 4
0
    def create_input_tensors_dict(self, batch_size):
        """Build the batched input tensor dict from the feed placeholders.

        The image, label, tag, old-label, flow, u0/u1 and flow placeholders are
        wrapped with ``create_tensor_dict``. The result is then resized using
        ``self._get_resize_params`` and batched/augmented with
        ``self._prepare_augmented_batch``. When ``self.use_summaries`` is set,
        an "inputs" and a "labels" image summary are appended to
        ``self.summaries``.

        Args:
            batch_size: batch size passed to ``_prepare_augmented_batch``.

        Returns:
            The batched tensor dict.
        """
        tensor_dict_kwargs = dict(
            unnormalized_img=self.img_placeholder,
            label=self.label_placeholder,
            tag=self.tag_placeholder,
            raw_label=self.label_placeholder,
            old_label=self.old_label_placeholder,
            flow_past=self.flow_into_past_placeholder,
            flow_future=self.flow_into_future_placeholder,
            # Enabled for every subset except "train".
            use_index_img=self.subset != "train",
            u0=self.u0_placeholder,
            u1=self.u1_placeholder,
            flow=self.flow_placeholder,
        )
        tensors = create_tensor_dict(**tensor_dict_kwargs)

        # TODO: need to set shape here?

        resize_mode, input_size = self._get_resize_params(self.subset, self.image_size)
        tensors = resize(tensors, resize_mode, input_size)
        # For 3-element sizes only the trailing two entries are used as the batch
        # image size (presumably the spatial dims — TODO confirm).
        batch_image_size = input_size[1:] if len(input_size) == 3 else input_size
        tensors = self._prepare_augmented_batch(tensors, batch_size,
                                                image_size=batch_image_size)

        if not self.use_summaries:
            return tensors

        inputs = tensors["inputs"]
        inputs_summary = tf.summary.image("inputs", unnormalize(inputs[:, :, :, :3]))
        # Scaling by 255 only gives a meaningful image for binary segmentation.
        labels_summary = tf.summary.image("labels", tensors["labels"] * 255)
        for summary in (inputs_summary, labels_summary):
            self.summaries.append(summary)
        return tensors
Ejemplo n.º 5
0
    def _create_summaries(self, tensors_dict):
        """Create TensorBoard image summaries for a batched input tensor dict.

        The first three channels of ``tensors_dict["inputs"]`` are treated as
        the normalized image. Extra channels are interpreted from the total
        channel count:

        * 4, 6 or 8 channels: channel 3 holds the old label. It is shown in
          green via ``self.add_clicks`` when ``Constants.OLD_LABEL_AS_DT`` is
          in the dict, and via ``self.get_masked_image`` otherwise.
        * more than 6 channels: the next four channels are drawn as
          red, red, green, green clicks and summarized as DT_NEG, DT_NEG,
          DT_POS, DT_POS.
        * 5 or 6 channels: the next two channels are drawn as red, green clicks
          and summarized as DT_NEG, DT_POS.

        If ``Constants.BBOXES`` is present, the boxes are drawn on the "inputs"
        summary. A "labels" summary is added unless the labels are a tuple.

        Args:
            tensors_dict: batched tensor dict containing at least "inputs" and
                "labels".

        Returns:
            list of the created summary ops.
        """
        inputs = tensors_dict["inputs"]
        num_channels = inputs.get_shape()[-1]
        input_imgs = unnormalize(inputs[:, :, :, :3])

        def channel(idx):
            # Single-channel slice that keeps the tensor 4-D.
            return inputs[:, :, :, idx:idx + 1]

        def overlay_clicks(imgs, idx, color):
            # self.add_clicks receives the image with channel `idx` concatenated
            # as an extra last channel, plus a colour code ('r' or 'g').
            [imgs] = tf.py_func(self.add_clicks,
                                [tf.concat([imgs, channel(idx)], axis=3), color],
                                [tf.float32])
            return imgs

        # Index of the next unconsumed channel after RGB (and after the old label
        # once that has been handled).
        start = 3
        summaries = []

        # Old label would be present either as a single channel, or along with
        # distance transform channels.
        if num_channels in [4, 6, 8]:
            if Constants.OLD_LABEL_AS_DT in tensors_dict:
                input_imgs = overlay_clicks(input_imgs, start, 'g')
            else:
                # The first output of get_masked_image is not used here.
                [_, input_imgs] = tf.py_func(self.get_masked_image,
                                             [input_imgs, channel(start)],
                                             [tf.float32, tf.float32])

            summaries.append(tf.summary.image("old_labels", channel(start)))
            start += 1

        # (click colour, summary name) for each channel following `start`.
        if num_channels > 6:
            click_layout = [('r', Constants.DT_NEG), ('r', Constants.DT_NEG),
                            ('g', Constants.DT_POS), ('g', Constants.DT_POS)]
        elif num_channels > 4:
            click_layout = [('r', Constants.DT_NEG), ('g', Constants.DT_POS)]
        else:
            click_layout = []

        # Add clicks to the image so that they can be viewed in tensorboard
        for offset, (color, _) in enumerate(click_layout):
            input_imgs = overlay_clicks(input_imgs, start + offset, color)

        # bounding boxes
        if Constants.BBOXES in tensors_dict:
            bboxes = tensors_dict[Constants.BBOXES]
            # Assumes stored order (y1, y2, x1, x2); permute to the
            # (y1, x1, y2, x2) order expected by tf.image.draw_bounding_boxes.
            bboxes = tf.stack([
                bboxes[..., 0], bboxes[..., 2], bboxes[..., 1], bboxes[..., 3]
            ], axis=-1)
            # normalize bboxes to [0..1]
            height = tf.shape(input_imgs)[1]
            width = tf.shape(input_imgs)[2]
            bboxes = tf.cast(bboxes, tf.float32) / tf.cast(
                tf.stack([height, width, height, width], axis=0), tf.float32)
            imgs_with_bboxes = tf.image.draw_bounding_boxes(input_imgs, bboxes)
            summaries.append(tf.summary.image("inputs", imgs_with_bboxes))
        else:
            summaries.append(tf.summary.image("inputs", input_imgs))

        if not isinstance(tensors_dict["labels"], tuple):
            # will only work well for binary segmentation
            summaries.append(tf.summary.image("labels", tensors_dict["labels"] * 255))

        # Append the distance transforms, if they are available.
        for offset, (_, name) in enumerate(click_layout):
            summaries.append(tf.summary.image(name, channel(start + offset)))

        return summaries