Example #1
0
    def add_3dvolume(self, volume, tag, global_step=None, walltime=None):
        """Save a 3D volume under the log dir and log a summary that names it.

        The volume is written as ``<log_dir>/<tag>_<suffix>.nii.gz`` with ANTs.
        A one-element ``DT_STRING`` tensor summary holding ``<tag>_<suffix>``
        (no directory, no extension) is then emitted under the
        ``tb_3d_volume_plugin`` plugin name. Presumably this is how the plugin
        finds the file again.

        Args:
            volume: A ``torch.Tensor`` or an array accepted by
                ``ants.from_numpy``. Tensors are detached and copied to CPU
                before conversion.
            tag: Summary tag, also used as the filename prefix. If it contains
                ``/``, the corresponding subdirectory of the log dir is created.
            global_step: Used as the filename suffix and passed on as the
                summary step. When ``None``, the current local time
                (second resolution) is used as the suffix instead.
            walltime: Optional event wall time forwarded to the file writer.
        """
        # NOTE(review): with global_step=None, two calls for the same tag within
        # the same second yield the same filename; the second overwrites the first.
        if global_step is None:
            suffix = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        else:
            suffix = str(global_step)
        filename = tag + "_" + suffix

        if isinstance(volume, torch.Tensor):
            volume = volume.detach().cpu().numpy()

        img = ants.from_numpy(volume)
        path = os.path.join(self._log_dir, filename + ".nii.gz")
        # A tag like "train/volume" places the file in a subdirectory of the log
        # dir that may not exist yet. Create it so the write does not fail.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        ants.image_write(img, path)

        # The plugin payload reuses the text plugin's serialized metadata.
        # NOTE(review): confirm tb_3d_volume_plugin expects this content format.
        plugin_data = tf.SummaryMetadata.PluginData(
            plugin_name="tb_3d_volume_plugin",
            content=TextPluginData(version=0).SerializeToString())
        metadata = tf.SummaryMetadata(plugin_data=plugin_data)
        # The summary carries only the bare filename (no log dir, no extension).
        tensor = TensorProto(
            dtype='DT_STRING',
            string_val=[filename.encode(encoding='utf_8')],
            tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]))
        summary = summary_pb2.Summary(value=[
            summary_pb2.Summary.Value(
                tag=tag, metadata=metadata, tensor=tensor)
        ])
        self._file_writer.add_summary(summary,
                                      global_step=global_step,
                                      walltime=walltime)
        self._file_writer.flush()
Example #2
0
def text(tag, text):
    """Return a ``Summary`` carrying ``text`` for the TensorBoard text plugin.

    The string is UTF-8 encoded into a one-element ``DT_STRING`` tensor. The
    value is tagged ``<tag>/text_summary`` and its metadata names the ``text``
    plugin.

    Args:
        tag: Base tag; ``'/text_summary'`` is appended to it.
        text: The string to record. Note that inside this function the
            parameter shadows the function's own name.

    Returns:
        A ``Summary`` holding exactly one ``Summary.Value``.
    """
    encoded = text.encode(encoding='utf_8')
    single_elem_shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)])
    payload = TensorProto(
        dtype='DT_STRING',
        string_val=[encoded],
        tensor_shape=single_elem_shape,
    )

    plugin_content = TextPluginData(version=0).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name='text',
                                               content=plugin_content))

    entry = Summary.Value(tag=tag + '/text_summary',
                          metadata=metadata,
                          tensor=payload)
    return Summary(value=[entry])