def test_summarywriter_image(self):
    """Check that a uint8 RGB image written with `image()` reads back intact.

    A random (30, 30, 3) uint8 image is written under the tag 'image_test'.
    The summary value is then fetched with `parse_and_return_summary_value`
    and the decoded image is compared with the one that was written.
    """
    # Use a self-cleaning temporary directory so the event files written by
    # this test are removed afterwards, even when an assertion fails.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    log_dir = tmp_dir.name

    summary_writer = SummaryWriter(log_dir=log_dir)
    expected_img = onp.random.uniform(low=0., high=255., size=(30, 30, 3))
    expected_img = expected_img.astype(onp.uint8)
    summary_writer.image(tag='image_test', image=expected_img, step=1)
    summary_value = self.parse_and_return_summary_value(path=log_dir)

    self.assertEqual(summary_value.tag, 'image_test')
    # The encoded image is read from index 2 of `string_val`; the first two
    # entries presumably hold the image dimensions -- TODO confirm.
    actual_img = tf.image.decode_image(summary_value.tensor.string_val[2])
    self.assertTrue(onp.allclose(actual_img, expected_img))
def test_summarywriter_single_channel_image_scaled(self):
    """Check that a single-channel image is stored with three channels.

    A random (30, 30, 1) uint8 image is written under the tag
    '2dimage_1channel_test'; the decoded image read back from the summary is
    expected to have shape (30, 30, 3).
    """
    # Use a self-cleaning temporary directory so the event files written by
    # this test are removed afterwards, even when an assertion fails.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    log_dir = tmp_dir.name

    summary_writer = SummaryWriter(log_dir=log_dir)
    img = onp.random.uniform(low=0., high=255., size=(30, 30, 1))
    img = img.astype(onp.uint8)
    summary_writer.image(tag='2dimage_1channel_test', image=img, step=1)
    summary_value = self.parse_and_return_summary_value(path=log_dir)

    self.assertEqual(summary_value.tag, '2dimage_1channel_test')
    # The encoded image is read from index 2 of `string_val`; the first two
    # entries presumably hold the image dimensions -- TODO confirm.
    actual_img = tf.image.decode_image(summary_value.tensor.string_val[2])
    # assert the image was increased in dimension
    self.assertEqual(actual_img.shape, (30, 30, 3))
def test_summarywriter_image_float_pixel_values(self):
    """Check that a float image is stored as its uint8 conversion.

    A random float image with values drawn from [0, 1) is written under the
    tag 'image_test'. The decoded image read back from the summary is compared
    with the input converted to uint8 by `tf.image.convert_image_dtype`.
    """
    # Use a self-cleaning temporary directory so the event files written by
    # this test are removed afterwards, even when an assertion fails.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    log_dir = tmp_dir.name

    summary_writer = SummaryWriter(log_dir=log_dir)
    expected_img = onp.random.uniform(low=0., high=1., size=(30, 30, 3))
    summary_writer.image(tag='image_test', image=expected_img, step=1)
    summary_value = self.parse_and_return_summary_value(path=log_dir)

    # convert and scale expected_img appropriately to numpy uint8.
    expected_img = tf.image.convert_image_dtype(
        image=expected_img, dtype=onp.uint8)

    self.assertEqual(summary_value.tag, 'image_test')
    # The encoded image is read from index 2 of `string_val`; the first two
    # entries presumably hold the image dimensions -- TODO confirm.
    actual_img = tf.image.decode_image(summary_value.tensor.string_val[2])
    self.assertTrue(onp.allclose(actual_img, expected_img))
def test_summarywriter_multiple_2dimages_scaled(self):
    """Check that a stack of 2D images is stored as three-channel images.

    A random (2, 30, 30) uint8 array is written under the tag
    'multiple_2dimages_test'; the decoded images read back from the summary,
    stacked along a new leading axis, are expected to have shape
    (2, 30, 30, 3).
    """
    # Use a self-cleaning temporary directory so the event files written by
    # this test are removed afterwards, even when an assertion fails.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    log_dir = tmp_dir.name

    summary_writer = SummaryWriter(log_dir=log_dir)
    img = np.random.uniform(low=0., high=255., size=(2, 30, 30))
    img = img.astype(np.uint8)
    summary_writer.image(tag='multiple_2dimages_test', image=img, step=1)
    summary_value = self.parse_and_return_summary_value(path=log_dir)

    self.assertEqual(summary_value.tag, 'multiple_2dimages_test')
    # Encoded images are read from index 2 onwards of `string_val`; the first
    # two entries presumably hold the image dimensions -- TODO confirm.
    actual_imgs = [
        tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:]
    ]
    # assert the images were increased in dimension
    self.assertEqual(np.stack(actual_imgs, axis=0).shape, (2, 30, 30, 3))
def test_summarywriter_multiple_images(self):
    """Check that a batch of uint8 RGB images reads back intact.

    A random (2, 30, 30, 3) uint8 array is written under the tag
    'multiple_images_test'; the decoded images read back from the summary are
    stacked along a new leading axis and compared with the array written.
    """
    # Use a self-cleaning temporary directory so the event files written by
    # this test are removed afterwards, even when an assertion fails.
    tmp_dir = tempfile.TemporaryDirectory()
    self.addCleanup(tmp_dir.cleanup)
    log_dir = tmp_dir.name

    summary_writer = SummaryWriter(log_dir=log_dir)
    expected_img = np.random.uniform(low=0., high=255., size=(2, 30, 30, 3))
    expected_img = expected_img.astype(np.uint8)
    summary_writer.image(tag='multiple_images_test', image=expected_img,
                         step=1)
    summary_value = self.parse_and_return_summary_value(path=log_dir)

    self.assertEqual(summary_value.tag, 'multiple_images_test')
    # Encoded images are read from index 2 onwards of `string_val`; the first
    # two entries presumably hold the image dimensions -- TODO confirm.
    actual_imgs = [
        tf.image.decode_image(s) for s in summary_value.tensor.string_val[2:]
    ]
    self.assertTrue(
        np.allclose(np.stack(actual_imgs, axis=0), expected_img))
def process_batch(*,
                  batch: Dict[str, jnp.ndarray],
                  rng: types.PRNGKey,
                  state: model_utils.TrainState,
                  tag: str,
                  item_id: str,
                  step: int,
                  summary_writer: tensorboard.SummaryWriter,
                  render_fn: Any,
                  save_dir: Optional[gpath.GPath],
                  datasource: datasets.DataSource):
    """Render one batch, log its visualizations and compute error metrics.

    Args:
      batch: the batch to render. Its optional 'rgb' and 'depth' entries are
        used as targets for the error metrics and error images.
      rng: random key forwarded to `render_fn`.
      state: training state forwarded to `render_fn`.
      tag: middle component of every image summary tag.
      item_id: identifier used in file names and summary tags; any '/' is
        replaced with '_' first.
      step: step passed to every image summary.
      summary_writer: writer receiving the image summaries.
      render_fn: called as `render_fn(state, batch, rng=rng)`; its result is
        indexed with the keys 'rgb', 'acc', 'depth' and 'med_depth'.
      save_dir: when truthy, directory (created if needed) that receives the
        rendered color and depth images.
      datasource: supplies `near` and `far`, used as the color range of the
        depth visualizations.

    Returns:
      A dict holding 'mse', 'psnr' and 'ssim' when the batch has an 'rgb'
      entry, and 'depth_abs' when it has a 'depth' entry. The dict is empty
      on every process whose index is not 0.
    """
    item_id = item_id.replace('/', '_')
    # The render is issued on every process; only process 0 plots and scores.
    render = render_fn(state, batch, rng=rng)
    metrics = {}
    if jax.process_index() != 0:
        return metrics

    rgb_pred = render['rgb']
    acc_pred = render['acc']
    expected_depth = render['depth']
    median_depth = render['med_depth']

    depth_to_color = functools.partial(
        viz.colorize, cmin=datasource.near, cmax=datasource.far, invert=True)
    expected_depth_viz = depth_to_color(expected_depth)
    median_depth_viz = depth_to_color(median_depth)
    expected_disp_viz = viz.colorize(1.0 / expected_depth)
    median_disp_viz = viz.colorize(1.0 / median_depth)
    acc_viz = viz.colorize(acc_pred, cmin=0.0, cmax=1.0)

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(
            save_dir / f'rgb_{item_id}.png',
            image_utils.image_to_uint8(rgb_pred))
        image_utils.save_image(
            save_dir / f'depth_expected_viz_{item_id}.png',
            image_utils.image_to_uint8(expected_depth_viz))
        image_utils.save_depth(
            save_dir / f'depth_expected_{item_id}.png', expected_depth)
        image_utils.save_image(
            save_dir / f'depth_median_viz_{item_id}.png',
            image_utils.image_to_uint8(median_depth_viz))
        image_utils.save_depth(
            save_dir / f'depth_median_{item_id}.png', median_depth)

    # Summaries of the render itself, written in this order.
    render_summaries = (
        (f'rgb/{tag}/{item_id}', rgb_pred),
        (f'depth-expected/{tag}/{item_id}', expected_depth_viz),
        (f'depth-median/{tag}/{item_id}', median_depth_viz),
        (f'disparity-expected/{tag}/{item_id}', expected_disp_viz),
        (f'disparity-median/{tag}/{item_id}', median_disp_viz),
        (f'acc/{tag}/{item_id}', acc_viz),
    )
    for summary_tag, image in render_summaries:
        summary_writer.image(summary_tag, image, step)

    if 'rgb' in batch:
        rgb_target = batch['rgb']
        mse = ((rgb_pred - rgb_target)**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb_pred)
        metrics.update(mse=mse, psnr=psnr, ssim=ssim)
        logging.info('\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f',
                     mse, psnr, ssim)

        # Per-pixel errors, summed over the last axis before colorizing.
        rgb_residual = rgb_target - rgb_pred
        rgb_abs_error = viz.colorize(
            abs(rgb_residual).sum(axis=-1), cmin=0, cmax=1)
        rgb_sq_error = viz.colorize(
            (rgb_residual**2).sum(axis=-1), cmin=0, cmax=1)
        rgb_summaries = (
            (f'rgb-target/{tag}/{item_id}', rgb_target),
            (f'rgb-abs-error/{tag}/{item_id}', rgb_abs_error),
            (f'rgb-sq-error/{tag}/{item_id}', rgb_sq_error),
        )
        for summary_tag, image in rgb_summaries:
            summary_writer.image(summary_tag, image, step)

    if 'depth' in batch:
        depth_target = batch['depth']
        depth_target_viz = depth_to_color(depth_target[..., 0])
        median_residual = depth_target - median_depth
        metrics['depth_abs'] = jnp.nanmean(jnp.abs(median_residual))
        summary_writer.image(
            f'depth-target/{tag}/{item_id}', depth_target_viz, step)

        median_error_viz = viz.colorize(
            abs(median_residual).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(
            f'depth-median-error/{tag}/{item_id}', median_error_viz, step)

        expected_error_viz = viz.colorize(
            abs(depth_target - expected_depth).squeeze(axis=-1),
            cmin=0, cmax=1)
        summary_writer.image(
            f'depth-expected-error/{tag}/{item_id}', expected_error_viz, step)

    return metrics
def process_batch(
    *,
    batch: Dict[str, jnp.ndarray],
    rng: types.PRNGKey,
    state: model_utils.TrainState,
    tag: str,
    item_id: str,
    step: int,
    summary_writer: tensorboard.SummaryWriter,
    render_fn: Any,
    save_dir: Optional[gpath.GPath],
    datasource: datasets.DataSource,
):
    """Process and plot a single batch.

    Renders the batch, optionally saves the renders to `save_dir`, writes
    image summaries and computes error metrics against the targets present
    in `batch`.

    Args:
      batch: the batch to render. Its optional "rgb" and "depth" entries are
        used as targets for the error metrics and error images.
      rng: random key forwarded to `render_fn`.
      state: training state forwarded to `render_fn`.
      tag: middle component of every image summary tag.
      item_id: identifier used in file names and summary tags.
      step: step passed to every image summary.
      summary_writer: writer receiving the image summaries.
      render_fn: called as `render_fn(state, batch, rng=rng)`; must return the
        tuple `(rgb, depth_exp, depth_med, acc)`.
      save_dir: when truthy, directory (created if needed) that receives the
        rendered color and depth images.
      datasource: supplies `near` and `far`, used as the color range of the
        depth visualizations.

    Returns:
      A dict holding "mse", "psnr" and "ssim" when the batch has an "rgb"
      entry, and "depth_abs" when it has a "depth" entry. The dict is empty
      on every process whose index is not 0.
    """
    # The render is issued on every process; only process 0 plots and scores.
    rgb, depth_exp, depth_med, acc = render_fn(state, batch, rng=rng)
    out = {}
    # `jax.process_index()` replaces the deprecated `jax.host_id()`.
    if jax.process_index() != 0:
        return out

    colorize_depth = functools.partial(viz.colorize,
                                       cmin=datasource.near,
                                       cmax=datasource.far,
                                       invert=True)

    # The depth maps are indexed with [..., 0] before being colorized/saved.
    depth_exp_viz = colorize_depth(depth_exp[..., 0])
    depth_med_viz = colorize_depth(depth_med[..., 0])

    if save_dir:
        save_dir.mkdir(parents=True, exist_ok=True)
        image_utils.save_image(save_dir / f"rgb_{item_id}.png",
                               image_utils.image_to_uint8(rgb))
        image_utils.save_image(
            save_dir / f"depth_expected_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_exp_viz),
        )
        # Save the expected depth here; the median depth has its own file
        # below.
        image_utils.save_depth(save_dir / f"depth_expected_{item_id}.png",
                               depth_exp[..., 0])
        image_utils.save_image(
            save_dir / f"depth_median_viz_{item_id}.png",
            image_utils.image_to_uint8(depth_med_viz),
        )
        image_utils.save_depth(save_dir / f"depth_median_{item_id}.png",
                               depth_med[..., 0])

    summary_writer.image(f"rgb/{tag}/{item_id}", rgb, step)
    summary_writer.image(f"depth-expected/{tag}/{item_id}", depth_exp_viz,
                         step)
    summary_writer.image(f"depth-median/{tag}/{item_id}", depth_med_viz, step)
    summary_writer.image(f"acc/{tag}/{item_id}", acc, step)

    if "rgb" in batch:
        rgb_target = batch["rgb"]
        mse = ((rgb - batch["rgb"])**2).mean()
        psnr = utils.compute_psnr(mse)
        ssim = compute_multiscale_ssim(rgb_target, rgb)
        out["mse"] = mse
        out["psnr"] = psnr
        out["ssim"] = ssim
        logging.info("\tMetrics: mse=%.04f, psnr=%.02f, ssim=%.02f", mse,
                     psnr, ssim)

        # Per-pixel errors, summed over the last axis before colorizing.
        rgb_abs_error = viz.colorize(abs(rgb_target - rgb).sum(axis=-1),
                                     cmin=0,
                                     cmax=1)
        rgb_sq_error = viz.colorize(((rgb_target - rgb)**2).sum(axis=-1),
                                    cmin=0,
                                    cmax=1)
        summary_writer.image(f"rgb-target/{tag}/{item_id}", rgb_target, step)
        summary_writer.image(f"rgb-abs-error/{tag}/{item_id}", rgb_abs_error,
                             step)
        summary_writer.image(f"rgb-sq-error/{tag}/{item_id}", rgb_sq_error,
                             step)

    if "depth" in batch:
        depth_target = batch["depth"]
        depth_target_viz = colorize_depth(depth_target[..., 0])
        out["depth_abs"] = jnp.nanmean(jnp.abs(depth_target - depth_med))
        summary_writer.image(f"depth-target/{tag}/{item_id}",
                             depth_target_viz, step)
        depth_med_error = viz.colorize(
            abs(depth_target - depth_med).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-median-error/{tag}/{item_id}",
                             depth_med_error, step)
        depth_exp_error = viz.colorize(
            abs(depth_target - depth_exp).squeeze(axis=-1), cmin=0, cmax=1)
        summary_writer.image(f"depth-expected-error/{tag}/{item_id}",
                             depth_exp_error, step)

        # NOTE(review): this summary does not use the depth target; it is
        # kept inside this branch on the assumption that this matches the
        # original nesting -- confirm.
        rel_disp_pred = viz.colorize(1.0 / depth_exp[..., 0])
        summary_writer.image(f"relative-disparity/{tag}/{item_id}",
                             rel_disp_pred, step)

    return out