def get_loss(self, inputs, outputs):
    """Build the combined PTN training loss (projection + volume terms).

    Each term is added only when its weight in ``self._params`` is truthy,
    so a zero/None weight skips constructing that loss entirely.

    Args:
      inputs: Model inputs; forwarded unchanged to the loss helpers.
      outputs: Model outputs; forwarded unchanged to the loss helpers.

    Returns:
      A float32 tensor with the summed loss. It is a zero scalar when both
      weights are disabled. The same tensor is also recorded as the
      'im2vox_loss' scalar summary under the 'losses' prefix.
    """
    params = self._params
    total_loss = tf.zeros(shape=[], dtype=tf.float32)

    # Projection term, scaled by its own step size and weight.
    if params.proj_weight:
        total_loss += losses.add_volume_proj_loss(
            inputs, outputs, params.step_size, params.proj_weight)

    # Volume term. The literal 1 is forwarded unchanged.
    # NOTE(review): presumably a view count — confirm against
    # losses.add_volume_loss's signature.
    if params.volume_weight:
        total_loss += losses.add_volume_loss(
            inputs, outputs, 1, params.volume_weight)

    slim.summaries.add_scalar_summary(total_loss, 'im2vox_loss', prefix='losses')
    return total_loss