def plot_detection_image(self, training_context):
    """Save a two-panel tile comparing ground-truth and predicted boxes.

    The left panel draws target boxes whose column 4 exceeds 0.9. The right
    panel draws predicted boxes whose column 4 exceeds 0.7 and whose arg-max
    class, after softmax over columns 5+, is not class 0. The tile is saved
    under ``self.save_path``. When TensorBoard is enabled on ``ctx``, it is
    also logged as a figure.

    Args:
        training_context: dict-like training state. This method reads
            ``'data_feed'``, ``'train_data'``, ``'current_model'``,
            ``'training_name'`` and ``'steps'``.
    """
    data_feed = training_context['data_feed']
    data = training_context['train_data']
    model = training_context['current_model']

    # NOTE(review): unlike plot_tile_image, the model is not switched to
    # eval() for this forward pass -- confirm whether that is intended.
    output = try_map_args_and_call(model, data, data_feed)
    target = data['bbox']
    model_input = data[data_feed[model.signature.inputs.key_list[0]]]

    # Both panels draw on copies, so a single reverse transform suffices.
    # NOTE(review): the whole input batch goes through reverse_image_transform
    # here; this assumes it yields one drawable image -- TODO confirm for
    # batch sizes > 1.
    input_image = self.reverse_image_transform(to_numpy(model_input))

    def target_label(box):
        # Columns from 5 onward identify the class. A single column is read
        # as a class index; several columns are read as per-class scores.
        # NOTE(review): target column layout is inferred, not shown here --
        # confirm against the dataset that produces data['bbox'].
        class_columns = box[5:]
        if len(class_columns) == 1:
            return self.labels[int(class_columns[0])]
        return self.labels[int(np.argmax(class_columns))]

    # Ground-truth panel.
    target_image = input_image.copy()
    target_boxes = to_numpy(xywh2xyxy(target[target[:, 4] > 0.9, :]))
    for box in target_boxes:
        plot_one_box(box, target_image, (255, 128, 128), target_label(box))

    # Prediction panel: confident detections whose arg-max class is not 0.
    keep = output[:, :, 4] > 0.7
    if len(output[:, :, 4]) > 0:
        keep = keep & (argmax(softmax(output[:, :, 5:], -1), -1) != 0)
    # Apply the mask exactly once. It already selects rows out of the 3-D
    # output and yields a 2-D array, so indexing that result with the same
    # 2-D mask again is a shape error.
    kept = output[keep, :]
    output_image = input_image.copy()
    output_boxes = to_numpy(xywh2xyxy(kept))
    for box in output_boxes:
        plot_one_box(box, output_image, (255, 255, 128), self.labels[int(np.argmax(box[5:]))])

    fig = tile_rgb_images(
        target_image,
        output_image,
        save_path=os.path.join(self.save_path, self.tile_image_name_prefix),
        imshow=True,
    )
    if ctx.enable_tensorboard and ctx.summary_writer is not None:
        ctx.summary_writer.add_figure(
            training_context['training_name'] + '/plot/detection_plot',
            fig,
            global_step=training_context['steps'],
            close=True,
            walltime=time.time(),
        )
    plt.close()
def plot_tile_image(self, training_context):
    """Save a tile image comparing inputs, targets and model predictions.

    Runs the model in eval mode on the batch input and collects the target
    array from the first non-output key containing 'target', 'label' or
    'mask'. It then tiles the de-normalised inputs next to either
    colour-coded label maps (more than two output channels) or mask-blended
    images over ``self.background``.

    NOTE(review): another ``plot_tile_image`` is defined immediately after
    this one; if both belong to the same class, the later definition
    replaces this one.

    Args:
        training_context: dict-like training state. This method reads
            ``'data_feed'``, ``'train_data'`` and ``'current_model'``.
    """
    data_feed = training_context['data_feed']
    data = training_context['train_data']
    model = training_context['current_model']
    # More than two output channels is treated as a multi-class label map.
    is_label_mask = bool(model.output_shape[0] > 2)
    input_key = data_feed[model.signature.inputs.key_list[0]]

    model_input = None
    target = None
    output = None
    for data_key in data.key_list:
        if data_key == input_key:
            model_input = data[input_key]
            model.eval()
            try:
                output = model(model_input)
            finally:
                # Restore training mode even if the forward pass fails.
                model.train()
        elif (
            ('target' in data_key or 'label' in data_key or 'mask' in data_key)
            and 'output' not in data_key
            and data_key in data_feed.value_list
        ):
            target = to_numpy(data[data_key])

    if 'alpha' not in data:
        # Channel axis 1 is assumed here (channel-first output).
        output = np.argmax(to_numpy(output), 1)
        if is_label_mask:
            target = label2color(target, self.palette)
            output = label2color(output, self.palette)
    else:
        output = to_numpy(output[:, 1, :, :] * argmax(output, 1))
        target = to_numpy(data['alpha'])

    model_input = to_numpy(model_input)
    input_arr = [self.reverse_image_transform(image) for image in model_input]
    tile_images_list = [input_arr]
    if is_label_mask:
        tile_images_list.append(target)
        tile_images_list.append(output)
    else:
        target_arr = np.expand_dims(target, -1)
        output_arr = np.expand_dims(output, -1)
        if 'alpha' not in data:
            # Copy before binarising: expand_dims returns a view of ``target``,
            # which may share memory with the batch's label tensor.
            target_arr = target_arr.copy()
            target_arr[target_arr > 0] = 1
            output_arr[output_arr > 0] = 1
        background = np.ones_like(target_arr) * self.background
        tile_images_list.append(target_arr * input_arr + (1 - target_arr) * background)
        tile_images_list.append(output_arr * input_arr + (1 - output_arr) * background)
    tile_rgb_images(*tile_images_list, save_path=os.path.join(self.save_path, self.tile_image_name_prefix), imshow=True)
def plot_tile_image(self, training_context):
    """Save (and optionally log) a tile comparing inputs, targets and predictions.

    Runs the model in eval mode on the batch input and collects the target
    array from the first non-output key containing 'target', 'label' or
    'mask'. With more than two output channels, colour-coded label maps are
    tiled. Otherwise the inputs are blended over ``self.background`` through
    the target and predicted masks. When ``'alpha'`` is in the batch, the
    prediction is a soft matte and the target is ``data['alpha']``.

    Args:
        training_context: dict-like training state. This method reads
            ``'data_feed'``, ``'train_data'``, ``'current_model'``,
            ``'training_name'`` and ``'steps'``.
    """
    is_tf = get_backend() == 'tensorflow'
    # Channel axis of the raw model output: last for TensorFlow, 1 otherwise.
    axis = -1 if is_tf else 1

    data_feed = training_context['data_feed']
    data = training_context['train_data']
    model = training_context['current_model']
    is_label_mask = bool(model.output_shape[model.filter_index] > 2)
    input_key = data_feed[model.signature.inputs.key_list[0]]

    model_input = None
    target = None
    output = None
    prediction = None
    for data_key in data.key_list:
        if data_key == input_key:
            model_input = data[input_key]
            model.eval()
            try:
                prediction = model(model_input)
            finally:
                # Restore training mode even if the forward pass fails.
                model.train()
            class_map = argmax(prediction, axis=axis)
            if is_label_mask:
                output = to_numpy(class_map)
            else:
                # Channel-last single-channel mask in the input's dtype.
                output = to_numpy(expand_dims(cast(class_map, model_input.dtype), axis=-1))
        elif (
            ('target' in data_key or 'label' in data_key or 'mask' in data_key)
            and 'output' not in data_key
            and data_key in data_feed.value_list
        ):
            target = to_numpy(data[data_key])

    output_arr = None
    if 'alpha' not in data:
        output_arr = output.copy()
        if is_label_mask:
            target = label2color(target, self.palette)
            output = label2color(output, self.palette)
    else:
        # Soft matte: the class-1 score, kept only where class 1 wins. It is
        # built from the raw prediction, because ``output`` already holds the
        # argmax map and has no class channel left to slice.
        scores = to_numpy(prediction)
        if not is_tf:
            scores = scores.transpose(0, 2, 3, 1)  # move channels last
        output = scores[..., 1:2] * np.argmax(scores, -1)[..., np.newaxis]
        # The blend below needs a mask in the alpha case too; previously it
        # was left as None and the multiplication raised TypeError.
        output_arr = output
        target = to_numpy(data['alpha'])

    model_input = to_numpy(model_input)
    input_arr = [self.reverse_image_transform(image) for image in model_input]
    tile_images_list = [input_arr]
    if is_label_mask:
        tile_images_list.append(target)
        tile_images_list.append(output)
    else:
        target_arr = target
        if np.ndim(target_arr) < np.ndim(model_input):
            # Append the channel axis last, matching output_arr and the
            # channel-last display images. Expanding at axis 1 on non-TF
            # backends produced a shape that cannot broadcast in the blend.
            target_arr = np.expand_dims(target_arr, -1)
        if 'alpha' not in data:
            # Copy before binarising so the batch's label data is not mutated
            # (``target`` may share memory with the input tensor).
            target_arr = np.array(target_arr, copy=True)
            target_arr[target_arr > 0] = 1
        background = np.ones_like(target_arr) * self.background
        tile_images_list.append(target_arr * input_arr + (1 - target_arr) * background)
        tile_images_list.append(output_arr * input_arr + (1 - output_arr) * background)

    fig = tile_rgb_images(
        *tile_images_list,
        save_path=os.path.join(self.save_path, self.tile_image_name_prefix),
        imshow=True,
    )
    if ctx.enable_tensorboard and ctx.summary_writer is not None:
        ctx.summary_writer.add_figure(
            training_context['training_name'] + '/plot/segtile_image',
            fig,
            global_step=training_context['steps'],
            close=True,
            walltime=time.time(),
        )
    plt.close()