    def plot_detection_image(self, training_context):
        tile_images_list = []
        input = None
        target = None
        output = None

        data_feed = training_context['data_feed']
        data = training_context['train_data']
        model = training_context['current_model']
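        # Run the model on the current batch; try_map_args_and_call resolves the
        # data_feed mapping between batch keys and the model's input signature.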
        output = try_map_args_and_call(model, data, data_feed)
        target = data['bbox']
        input = data[data_feed[model.signature.inputs.key_list[0]]]
        input_image = self.reverse_image_transform(to_numpy(input))
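        # Keep ground-truth boxes whose confidence (column 4) exceeds 0.9.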
        targetmask = (target[:, 4] > 0.9)
        input_image1 = input_image.copy()
        target_boxes = to_numpy(xywh2xyxy(target[targetmask, :]))
        for box in target_boxes:
            # Assumption: column 5 of a ground-truth row holds the class index.
            plot_one_box(box, input_image1, (255, 128, 128),
                         self.labels[int(box[5])])

        # input_arr=np.asarray(input_arr)
        tile_images_list.append(input_image1)

        input_image = self.reverse_image_transform(to_numpy(input))
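        # Keep predictions with objectness > 0.7 whose argmax over the class
        # scores (columns 5+) is not the background class 0.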
        mask = (output[:, :, 4] > 0.7)
        if len(output[:, :, 4]) > 0:
            mask2 = (argmax(softmax(output[:, :, 5:], -1), -1) != 0)
            mask = (mask.float() + mask2.float() == 2)
        output = output[mask, :]
        input_image2 = input_image.copy()
        # output has already been filtered by mask above.
        output_boxes = to_numpy(xywh2xyxy(output))
        for box in output_boxes:
            plot_one_box(box, input_image2, (255, 255, 128),
                         self.labels[np.argmax(box[5:])])

        tile_images_list.append(input_image2)
        fig = tile_rgb_images(*tile_images_list,
                              save_path=os.path.join(
                                  self.save_path, self.tile_image_name_prefix),
                              imshow=True)
        if ctx.enable_tensorboard and ctx.summary_writer is not None:
            ctx.summary_writer.add_figure(
                training_context['training_name'] + '/plot/detection_plot',
                fig,
                global_step=training_context['steps'],
                close=True,
                walltime=time.time())
        plt.close()

    def plot_tile_image(self, training_context):
        tile_images_list = []
        input = None
        target = None
        output = None
        is_label_mask = False
        data_feed = training_context['data_feed']
        data = training_context['train_data']
        model = training_context['current_model']
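        # More than two output channels means a multi-class label mask that will
        # be rendered through the color palette instead of a binary matte.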
        if model.output_shape[0] > 2:
            is_label_mask = True
        # if len(data) >= 3:
        for data_key in data.key_list:
            if data_key == data_feed[model.signature.inputs.key_list[0]]:
                input = data[data_feed[model.signature.inputs.key_list[0]]]

                training_context['current_model'].eval()
                output = model(input)
                training_context['current_model'].train()

            elif (
                    'target' in data_key or 'label' in data_key
                    or 'mask' in data_key
            ) and not 'output' in data_key and data_key in data_feed.value_list:
                target = to_numpy(data[data_key])

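        # An 'alpha' key in the batch indicates a matting task: keep the soft
        # foreground probability instead of a hard argmax mask.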
        if 'alpha' not in data:
            output = np.argmax(to_numpy(output), 1)
            if is_label_mask:
                target = label2color(target, self.palette)
                output = label2color(output, self.palette)
        else:
            output = to_numpy(output[:, 1, :, :] * argmax(output, 1))
            target = to_numpy(data['alpha'])

        input_arr = []
        input = to_numpy(input)
        for i in range(len(input)):
            input_arr.append(self.reverse_image_transform(input[i]))
        # input_arr=np.asarray(input_arr)
        tile_images_list.append(input_arr)

        if is_label_mask:
            tile_images_list.append(target)
            tile_images_list.append(output)
        else:
            target_arr = np.expand_dims(target, -1)
            output_arr = np.expand_dims(output, -1)
            if 'alpha' not in data:
                target_arr[target_arr > 0] = 1

            background = np.ones_like(target_arr) * self.background
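            # Composite: mask * image + (1 - mask) * constant background color.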

            tile_images_list.append(target_arr * input_arr +
                                    (1 - target_arr) * background)

            if 'alpha' not in data:
                output_arr[output_arr > 0] = 1

            tile_images_list.append(output_arr * input_arr +
                                    (1 - output_arr) * background)

        # if self.tile_image_include_mask:
        #     tile_images_list.append(input*127.5+127.5)
        tile_rgb_images(*tile_images_list,
                        save_path=os.path.join(self.save_path,
                                               self.tile_image_name_prefix),
                        imshow=True)

    def plot_tile_image(self, training_context):
        axis = 1
        if get_backend() == 'tensorflow':
            axis = -1
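        # Channel axis depends on the backend: NCHW for PyTorch, NHWC for TensorFlow.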

        tile_images_list = []
        input = None
        target = None
        output = None
        is_label_mask = False
        data_feed = training_context['data_feed']
        data = training_context['train_data']
        model = training_context['current_model']
        if model.output_shape[model.filter_index] > 2:
            is_label_mask = True
        # if len(data) >= 3:
        for data_key in data.key_list:
            if data_key == data_feed[model.signature.inputs.key_list[0]]:
                input = data[data_feed[model.signature.inputs.key_list[0]]]
                model.eval()
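                # Label masks keep the raw class-index map; otherwise the argmax is
                # expanded to a single channel and cast so it can be blended below.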
                if is_label_mask:
                    output = to_numpy(argmax(model(input), axis=axis))
                else:
                    output = to_numpy(
                        expand_dims(cast(argmax(model(input), axis=axis),
                                         input.dtype),
                                    axis=-1))

                model.train()

            # elif data_key == data_feed[model.signature.outputs.key_list[0]]:
            #     output = data[data_feed[model.signature.outputs.key_list[0]]]
            #     if output.max() < 0:
            #         output = exp(output)

            elif (
                    'target' in data_key or 'label' in data_key
                    or 'mask' in data_key
            ) and not 'output' in data_key and data_key in data_feed.value_list:
                target = to_numpy(data[data_key])
        output_arr = None
        if 'alpha' not in data:
            output_arr = output.copy()
            if is_label_mask:
                target = label2color(target, self.palette)
                output = label2color(output, self.palette)
        else:
            if get_backend() == 'tensorflow':
                output = output[:, :, :, 1:2] * argmax(output, axis)
            else:
                output = (output[:, 1:2, :, :] *
                          argmax(output, axis)).transpose(0, 2, 3, 1)
            target = to_numpy(data['alpha'])
            # Assumption: reuse the soft alpha map for the blend below so that
            # output_arr is never None when 'alpha' is present in the batch.
            output_arr = output

        input_arr = []
        input = to_numpy(input)
        for i in range(len(input)):
            input_arr.append(self.reverse_image_transform(input[i]))
        # input_arr=np.asarray(input_arr)
        tile_images_list.append(input_arr)

        if is_label_mask:
            tile_images_list.append(target)
            tile_images_list.append(output)
        else:
            target_arr = target

            if len(target.shape) < len(int_shape(input)):
                # reverse_image_transform yields channel-last images, so the mask
                # channel is appended on the last axis regardless of backend.
                target_arr = np.expand_dims(target, -1)

            if 'alpha' not in data:
                target_arr[target_arr > 0] = 1

            background = np.ones_like(target_arr) * self.background

            tile_images_list.append(target_arr * input_arr +
                                    (1 - target_arr) * background)

            tile_images_list.append(output_arr * input_arr +
                                    (1 - output_arr) * background)

        # if self.tile_image_include_mask:
        #     tile_images_list.append(input*127.5+127.5)
        fig = tile_rgb_images(*tile_images_list,
                              save_path=os.path.join(
                                  self.save_path, self.tile_image_name_prefix),
                              imshow=True)
        if ctx.enable_tensorboard and ctx.summary_writer is not None:
            ctx.summary_writer.add_figure(
                training_context['training_name'] + '/plot/segtile_image',
                fig,
                global_step=training_context['steps'],
                close=True,
                walltime=time.time())
        plt.close()