def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
    """Save a 3D volume as a NIfTI file and log a summary that references it.

    The volume is written to ``<tag>_<suffix>.nii.gz`` under
    ``self._log_dir``, where ``<suffix>`` is ``global_step`` or, when that is
    ``None``, the current local time formatted as ``%Y-%m-%d_%H-%M-%S``.
    A string-tensor summary holding that file name (without the extension)
    is then added under the ``tb_3d_volume_plugin`` plugin name and the file
    writer is flushed.

    Args:
        volume: A ``torch.Tensor`` or an array accepted by
            ``ants.from_numpy``. Tensors are detached, moved to the CPU and
            converted to NumPy before being written.
        tag: Summary tag; also used as the prefix of the file name.
        global_step: Optional step value, used both as the file-name suffix
            and as the step of the summary.
        walltime: Optional wall time forwarded to the file writer.
    """
    if global_step is None:
        suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    else:
        suffix = str(global_step)
    filename = tag + "_" + suffix

    if isinstance(volume, torch.Tensor):
        volume = volume.detach().cpu().numpy()

    volume_path = os.path.join(self._log_dir, filename + ".nii.gz")
    # Tags frequently contain "/" (e.g. "val/prediction"), which places the
    # file in a sub-directory of the log dir; create it so the write below
    # does not fail on a missing directory.
    parent_dir = os.path.dirname(volume_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    img = ants.from_numpy(volume)
    ants.image_write(img, volume_path)

    plugin_data = tf.SummaryMetadata.PluginData(
        plugin_name="tb_3d_volume_plugin",
        content=TextPluginData(version=0).SerializeToString())
    metadata = tf.SummaryMetadata(plugin_data=plugin_data)

    # The summary carries only the file name; the volume itself stays on disk.
    tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[filename.encode(encoding='utf_8')],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))

    summary = summary_pb2.Summary(value=[
        summary_pb2.Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
    ])

    self._file_writer.add_summary(summary, global_step=global_step, walltime=walltime)
    self._file_writer.flush()
def text(tag, text):
    """Build a ``Summary`` holding one UTF-8 encoded string for the text plugin.

    Args:
        tag: Base tag; the summary value is stored under
            ``tag + '/text_summary'``.
        text: The string to store; it is encoded with ``utf_8``.

    Returns:
        A ``Summary`` with a single value whose tensor is a ``DT_STRING``
        tensor with one dimension of size 1 and whose metadata names the
        ``text`` plugin.
    """
    serialized_plugin_content = TextPluginData(version=0).SerializeToString()
    text_plugin = SummaryMetadata.PluginData(
        plugin_name='text',
        content=serialized_plugin_content,
    )
    summary_metadata = SummaryMetadata(plugin_data=text_plugin)

    single_element_shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    string_tensor = TensorProto(
        dtype='DT_STRING',
        string_val=[text.encode(encoding='utf_8')],
        tensor_shape=single_element_shape,
    )

    summary_value = Summary.Value(
        tag=tag + '/text_summary',
        metadata=summary_metadata,
        tensor=string_tensor,
    )
    return Summary(value=[summary_value])