Example #1
class JSONSaver(Saver):
    """Saves an SPN to a JSON file.

    Args:
        path (str): Full path to the file.
        pretty (bool): Use pretty printing.
        sess (Session): Optional. Session used to retrieve parameter values.
                        If ``None``, the default session is used.
    """

    __logger = get_logger()
    __info = __logger.info
    __debug1 = __logger.debug1

    def __init__(self, path, pretty=False):
        super().__init__(path)
        self._pretty = pretty

    @utils.docinherit(Saver)
    def save(self, root, save_param_vals=True, sess=None):
        self.__info("Saving SPN graph rooted in '%s' to file '%s'" %
                    (root, self._path))
        data = serialize_graph(root,
                               save_param_vals=save_param_vals,
                               sess=sess)
        utils.json_dump(self._path, data, pretty=self._pretty)
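
A minimal usage sketch (not part of the original example), assuming libspn is importable as spn and exposes IVs, Product, Sum, generate_weights and JSONSaver at the package level:

import libspn as spn
import tensorflow as tf

# Build a tiny SPN: two binary variables, two products, one sum as root.
iv = spn.IVs(num_vars=2, num_vals=2)
prod1 = spn.Product((iv, [0, 2]))
prod2 = spn.Product((iv, [1, 3]))
root = spn.Sum(prod1, prod2)
spn.generate_weights(root, initializer=tf.initializers.random_uniform(0.0, 1.0))

saver = spn.JSONSaver("model.spn.json", pretty=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(root, save_param_vals=True, sess=sess)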
Example #2
class JSONLoader(Loader):
    """Loads an SPN from a JSON file.

    Args:
        path (str): Full path to the file.
        sess (Session): Optional. Session used to assign parameter values.
                        If ``None``, the default session is used.
    """

    __logger = get_logger()
    __info = __logger.info
    __debug1 = __logger.debug1

    def __init__(self, path):
        super().__init__(path)

    @utils.docinherit(Loader)
    def load(self, load_param_vals=True, sess=None):
        self.__info("Loading SPN graph from file '%s'" % self._path)
        data = utils.json_load(self._path)
        root = deserialize_graph(data,
                                 load_param_vals=load_param_vals,
                                 sess=sess,
                                 nodes_by_name=self._nodes_by_name)
        return root
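
A matching load sketch (an assumption: the file was produced by the JSONSaver example above and libspn is imported as spn):

import libspn as spn
import tensorflow as tf

loader = spn.JSONLoader("model.spn.json")
with tf.Session() as sess:
    # The session is used to assign the saved parameter values.
    root = loader.load(load_param_vals=True, sess=sess)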
Example #3
class Product(OpNode):
    """A node representing a single product in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values.
        name (str): Name of the node.
    """

    __logger = get_logger()
    __info = __logger.info

    def __init__(self, *values, name="Product"):
        self._values = []
        super().__init__(inference_type=InferenceType.MARGINAL, name=name)
        self.set_values(*values)

    def serialize(self):
        data = super().serialize()
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        return data

    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()

    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['values'])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return self._values

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def add_values(self, *values):
        """Add more inputs providing input values to this node.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._values + self._parse_inputs(*values)

    @property
    def _const_out_size(self):
        return True

    def _compute_out_size(self, *input_out_sizes):
        return 1

    def _compute_scope(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        value_scopes = self._gather_input_scopes(*value_scopes)
        return [Scope.merge_scopes(chain.from_iterable(value_scopes))]

    def _compute_valid(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        value_scopes_ = self._gather_input_scopes(*value_scopes)
        # If already invalid, return None
        if any(s is None for s in value_scopes_):
            return None
        # Check product decomposability
        flat_value_scopes = list(chain.from_iterable(value_scopes_))
        for s1, s2 in combinations(flat_value_scopes, 2):
            if s1 & s2:
                self.__info("%s is not decomposable", self)
                return None
        return self._compute_scope(*value_scopes)

    @utils.lru_cache
    def _compute_value_common(self, *value_tensors):
        """Common actions when computing value."""
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        # Prepare values
        value_tensors = self._gather_input_tensors(*value_tensors)
        if len(value_tensors) > 1:
            values = tf.concat(values=value_tensors, axis=1)
        else:
            values = value_tensors[0]
        return values

    @utils.lru_cache
    def _compute_log_value(self, *value_tensors):
        values = self._compute_value_common(*value_tensors)

        # Wrap the log value with its custom gradient
        @tf.custom_gradient
        def log_value(*value_tensors):
            # Defines gradient for the log value
            def gradient(gradients):
                scattered_grads = self._compute_log_mpe_path(
                    gradients, *value_tensors)
                return [sg for sg in scattered_grads if sg is not None]

            return tf.reduce_sum(values, 1, keepdims=True), gradient

        if conf.custom_gradient:
            return log_value(*value_tensors)
        else:
            return tf.reduce_sum(values, 1, keepdims=True)

    def _compute_log_mpe_value(self, *value_tensors):
        return self._compute_log_value(*value_tensors)

    @utils.lru_cache
    def _compute_log_mpe_path(self,
                              counts,
                              *value_values,
                              use_unweighted=False,
                              sample=False,
                              sample_prob=None):
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)

        def process_input(v_input, v_value):
            input_size = v_input.get_size(v_value)
            # Tile the counts if input is larger than 1
            return (tf.tile(counts, [1, input_size])
                    if input_size > 1 else counts)

        # For each input, pass counts to all elements selected by indices
        value_counts = [(process_input(v_input, v_value), v_value)
                        for v_input, v_value in zip(self._values, value_values)
                        ]
        # TODO: Scatter to input tensors can be merged with tiling to reduce
        # the amount of operations.
        return self._scatter_to_input_tensors(*value_counts)

    def _compute_log_gradient(self, gradients, *value_values):
        return self._compute_log_mpe_path(gradients, *value_values)

    def disconnect_inputs(self):
        self._values = None
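
A short sketch (not from the source) of wiring up a Product node, assuming libspn is imported as spn and that LogValue is available as used in the GDLearning example below:

import libspn as spn

# Three binary variables; picking one indicator per variable keeps the scopes
# disjoint, so the product is decomposable.
iv = spn.IVs(num_vars=3, num_vals=2)
prod = spn.Product((iv, 0), (iv, 2))
prod.add_values((iv, 4))  # a value input can also be attached after construction

# Assemble the TF op computing the log value of this node.
log_value_op = spn.LogValue().get_value(prod)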
Example #4
class GDLearning:
    """Assembles TF operations performing Gradient Descent learning of an SPN.

    Args:
        value_inference_type (InferenceType): The inference type used during the
            upwards pass through the SPN. Ignored if ``mpe_path`` is given.
        learning_rate (float): Learning rate parameter used for updating SPN weights.
        learning_task_type (LearningTaskType): Learning type used while learning.
        learning_method (LearningMethodType): Learning method type, can be either generative
            (LearningMethodType.GENERATIVE) or discriminative (LearningMethodType.DISCRIMINATIVE).
        marginalizing_root (Sum, ParSums, SumsLayer): A sum node without IVs attached to it (or
            IVs with a fixed no-evidence feed). If it is omitted here, the node will be constructed
            internally once needed.
        name (str): The name given to this instance of GDLearning.
        l1_regularize_coeff (float or Tensor): The L1 regularization coefficient.
        l2_regularize_coeff (float or Tensor): The L2 regularization coefficient.
    """

    __logger = get_logger()

    def __init__(self,
                 root,
                 value=None,
                 value_inference_type=None,
                 dropconnect_keep_prob=None,
                 learning_task_type=LearningTaskType.SUPERVISED,
                 learning_method=LearningMethodType.DISCRIMINATIVE,
                 marginalizing_root=None,
                 name="GDLearning",
                 l1_regularize_coeff=None,
                 l2_regularize_coeff=None,
                 optimizer=None):

        if learning_task_type == LearningTaskType.UNSUPERVISED and \
                learning_method == LearningMethodType.DISCRIMINATIVE:
            raise ValueError(
                "It is not possible to do unsupervised learning discriminatively."
            )

        self._root = root
        self._marginalizing_root = marginalizing_root
        if self._turn_off_dropconnect(dropconnect_keep_prob,
                                      learning_task_type):
            self._root.set_dropconnect_keep_prob(1.0)
            if self._marginalizing_root is not None:
                self._marginalizing_root.set_dropconnect_keep_prob(1.0)

        if value is not None and isinstance(value, LogValue):
            self._log_value = value
        else:
            if value is not None:
                GDLearning.__logger.warn(
                    "{}: Value instance is ignored since the current implementation does "
                    "not support gradients with non-log inference. Using a LogValue instance "
                    "instead.".format(name))
            self._log_value = LogValue(
                value_inference_type,
                dropconnect_keep_prob=dropconnect_keep_prob)
        self._learning_task_type = learning_task_type
        self._learning_method = learning_method
        self._l1_regularize_coeff = l1_regularize_coeff
        self._l2_regularize_coeff = l2_regularize_coeff
        self._dropconnect_keep_prob = dropconnect_keep_prob
        self._optimizer = optimizer
        self._name = name

    def learn(self, loss=None, optimizer=None, post_gradient_ops=True):
        """Assemble TF operations performing GD learning of the SPN. This includes setting up
        the loss function (with regularization), setting up the optimizer and setting up
        post gradient-update ops.

        Args:
            loss (Tensor): The operation corresponding to the loss to minimize.
            optimizer (tf.train.Optimizer): A TensorFlow optimizer to use for minimizing the loss.
            post_gradient_ops (bool): Whether to return the post gradient-update ops
                (e.g. weight normalization) grouped together with the update op.

        Returns:
            A tuple of grouped update Ops and a loss Op.
        """
        if self._learning_task_type == LearningTaskType.SUPERVISED and self._root.ivs is None:
            raise StructureError(
                "{}: the SPN rooted at {} does not have a latent IVs node, so cannot setup "
                "conditional class probabilities.".format(
                    self._name, self._root))

        # If a loss function is not provided, define the loss function based
        # on learning-type and learning-method
        with tf.name_scope("Loss"):
            if loss is None:
                loss = (self.negative_log_likelihood() if self._learning_method
                        == LearningMethodType.GENERATIVE else
                        self.cross_entropy_loss())
            if self._l1_regularize_coeff is not None or self._l2_regularize_coeff is not None:
                loss += self.regularization_loss()

        # Assemble TF ops for optimizing and weights normalization
        optimizer = optimizer if optimizer is not None else self._optimizer
        if optimizer is None:
            raise ValueError("Did not specify GD optimizer")
        with tf.name_scope("ParameterUpdate"):
            minimize = optimizer.minimize(loss=loss)
            if post_gradient_ops:
                return self.post_gradient_update(minimize), loss
            else:
                return minimize, loss

    def post_gradient_update(self, update_op):
        """Constructs post-parameter update ops such as normalization of weights and clipping of
        scale parameters of GaussianLeaf nodes.

        Args:
            update_op (Tensor): A Tensor corresponding to the parameter update.

        Returns:
            An updated operation where the post-processing has been ensured by TensorFlow's control
            flow mechanisms.
        """
        with tf.name_scope("PostGradientUpdate"):

            # After applying gradients to weights, normalize weights
            with tf.control_dependencies([update_op]):
                weight_norm_ops = []

                def fun(node):
                    if node.is_param:
                        weight_norm_ops.append(node.normalize())

                    if isinstance(node, GaussianLeaf
                                  ) and node.learn_distribution_parameters:
                        weight_norm_ops.append(
                            tf.assign(
                                node.scale_variable,
                                tf.maximum(node.scale_variable,
                                           node._min_stddev)))

                with tf.name_scope("WeightNormalization"):
                    traverse_graph(self._root, fun=fun)
            return tf.group(*weight_norm_ops, name="weight_norm")

    def cross_entropy_loss(self,
                           name="CrossEntropy",
                           reduce_fn=tf.reduce_mean,
                           dropconnect_keep_prob=None):
        """Sets up the cross entropy loss, which is equivalent to -log(p(Y|X)).

        Args:
            name (str): Name of the name scope for the Ops defined here
            reduce_fn (Op): An operation that reduces the losses for all samples to a scalar.
            dropconnect_keep_prob (float or Tensor): Keep probability for dropconnect, will
                override the value of GDLearning._dropconnect_keep_prob.

        Returns:
            A Tensor corresponding to the cross-entropy loss.
        """
        dropconnect_keep_prob = dropconnect_keep_prob if dropconnect_keep_prob is not None else \
            self._dropconnect_keep_prob
        with tf.name_scope(name):
            log_prob_data_and_labels = LogValue(
                dropconnect_keep_prob=dropconnect_keep_prob).get_value(
                    self._root)
            log_prob_data = self._log_likelihood(
                dropconnect_keep_prob=dropconnect_keep_prob)
            return -reduce_fn(log_prob_data_and_labels - log_prob_data)

    def negative_log_likelihood(self,
                                name="NegativeLogLikelihood",
                                reduce_fn=tf.reduce_mean,
                                dropconnect_keep_prob=None):
        """Returns the maximum (log) likelihood estimate loss function which corresponds to
        -log(p(X)) in the case of unsupervised learning or -log(p(X,Y)) in the case of supervised
        learning.

        Args:
            name (str): The name for the name scope to use
            reduce_fn (Op): An operation that reduces the losses for all samples to a scalar.
            dropconnect_keep_prob (float or Tensor): Keep probability for dropconnect, will
                override the value of GDLearning._dropconnect_keep_prob.
        Returns:
            A Tensor corresponding to the MLE loss
        """
        with tf.name_scope(name):
            if self._learning_task_type == LearningTaskType.UNSUPERVISED:
                if self._root.ivs is not None:
                    likelihood = self._log_likelihood(
                        dropconnect_keep_prob=dropconnect_keep_prob)
                else:
                    likelihood = self._log_value.get_value(self._root)
            elif self._root.ivs is None:
                raise StructureError(
                    "Root should have IVs node when doing supervised learning."
                )
            else:
                likelihood = self._log_value.get_value(self._root)
            return -reduce_fn(likelihood)

    def _log_likelihood(self,
                        learning_task_type=None,
                        dropconnect_keep_prob=None):
        """Computes log(p(X)) by creating a copy of the root node without IVs. Also turns off
        dropconnect at the root if necessary.

        Returns:
            A Tensor of shape [batch, 1] corresponding to the log likelihood of the data.
        """
        marginalizing_root = self._marginalizing_root or Sum(
            *self._root.values, weights=self._root.weights)
        learning_task_type = learning_task_type or self._learning_task_type
        dropconnect_keep_prob = dropconnect_keep_prob or self._dropconnect_keep_prob
        if self._turn_off_dropconnect(dropconnect_keep_prob,
                                      learning_task_type):
            marginalizing_root.set_dropconnect_keep_prob(1.0)
        return self._log_value.get_value(marginalizing_root)

    def regularization_loss(self, name="Regularization"):
        """Adds regularization to the weight nodes. This can be either L1 or L2 or both, depending
        on what is specified at instantiation of GDLearning.

        Returns:
            A Tensor with the total regularization loss.
        """

        with tf.name_scope(name):
            losses = []

            def regularize_node(node):
                if node.is_param:
                    if self._l1_regularize_coeff is not None:
                        losses.append(self._l1_regularize_coeff *
                                      tf.reduce_sum(tf.abs(node.variable)))
                    if self._l2_regularize_coeff is not None:
                        losses.append(self._l2_regularize_coeff *
                                      tf.reduce_sum(tf.square(node.variable)))

            traverse_graph(self._root, fun=regularize_node)
            return tf.add_n(losses)

    @staticmethod
    def _turn_off_dropconnect(dropconnect_keep_prob, learning_task_type):
        """Determines whether to turn off dropconnect for the root node. """
        return dropconnect_keep_prob is not None and \
            (not isinstance(dropconnect_keep_prob, (int, float)) or dropconnect_keep_prob == 1.0) \
            and learning_task_type == LearningTaskType.SUPERVISED

    @property
    def value(self):
        """Value or LogValue: Computed SPN values."""
        return self._log_value
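
A sketch (not from the source) of assembling the learning ops, assuming GDLearning and the enum types are exposed at the libspn package level:

import libspn as spn
import tensorflow as tf

# A tiny supervised SPN: one product per class, combined by a root sum with
# a latent class IVs attached (required for SUPERVISED learning).
sample_ivs = spn.IVs(num_vars=2, num_vals=2)
class_ivs = spn.IVs(num_vars=1, num_vals=2)
class0 = spn.Product((sample_ivs, [0, 2]))
class1 = spn.Product((sample_ivs, [1, 3]))
root = spn.Sum(class0, class1, ivs=class_ivs)
spn.generate_weights(root, initializer=tf.initializers.random_uniform(0.0, 1.0))

learning = spn.GDLearning(
    root,
    learning_task_type=spn.LearningTaskType.SUPERVISED,
    learning_method=spn.LearningMethodType.DISCRIMINATIVE,
    l2_regularize_coeff=1e-4,
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-2))
update_op, loss = learning.learn()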
Example #5
class MNISTDataset(Dataset):
    """A dataset providing MNIST data with various types of processing applied.

    The data is returned as a tuple of tensors ``(samples, labels)``, where
    ``samples`` has shape ``[batch_size, width*height]`` and contains
    flattened image data, and ``labels`` has shape ``[batch_size, 1]`` and
    contains integer labels representing the digits in the images.

    Args:
        subset (Subset): Determines what data to provide.
        format (ImageFormat): Image format.
        num_epochs (int): Number of epochs of produced data.
        batch_size (int): Size of a single batch.
        shuffle (bool): Shuffle data within each epoch.
        ratio (int): Downsample by the given ratio.
        crop (int): Crop that many border pixels (after downsampling).
        num_threads (int): Number of threads enqueuing the data queue. If
                           larger than ``1``, the performance will be better,
                           but examples might not be in order even if
                           ``shuffle`` is ``False``.
        allow_smaller_final_batch(bool): If ``False``, the last batch will be
                                         omitted if it has less elements than
                                         ``batch_size``.
        classes (list of int): Optional. If specified, only the listed classes
                               will be provided.
        seed (int): Optional. Seed used when shuffling.
    """

    __logger = get_logger()
    __info = __logger.info
    __debug1 = __logger.debug1

    class Subset(Enum):
        """Specifies what data is provided."""

        ALL = 0
        """Provide all data as one dataset combined of training and test samples."""

        TRAIN = 1
        """Provide only training samples"""

        TEST = 2
        """Provide only test samples."""

    def __init__(self, subset, format, num_epochs, batch_size,
                 shuffle, ratio=1, crop=0, num_threads=1,
                 allow_smaller_final_batch=False, classes=None, seed=None):
        self._orig_width = 28
        self._orig_height = 28
        if subset not in MNISTDataset.Subset:
            raise ValueError("Incorrect subset: %s" % subset)
        self._subset = subset
        if format not in {ImageFormat.FLOAT, ImageFormat.INT, ImageFormat.BINARY}:
            raise ValueError("Incorrect format: %s, "
                             "only FLOAT, INT and BINARY are accepted" % format)
        self._format = format
        if not isinstance(ratio, int):
            raise ValueError("ratio must be an integer")
        if ratio not in {1, 2, 4}:
            raise ValueError("ratio must be one of {1, 2, 4}")
        self._ratio = ratio
        self._width = self._orig_width // ratio
        self._height = self._orig_height // ratio
        if not isinstance(crop, int):
            raise ValueError("crop must be an integer")
        if crop < 0 or crop > (self._width // 2) or crop > (self._height // 2):
            raise ValueError("invalid value of crop")
        self._crop = crop
        self._width -= 2 * crop
        self._height -= 2 * crop
        self._num_channels = 1
        super().__init__(num_vars=(self._height * self._width * self._num_channels),
                         num_vals=format.num_vals,
                         num_labels=1,
                         num_epochs=num_epochs, batch_size=batch_size,
                         shuffle=shuffle,
                         # We shuffle the samples in this class
                         # so batch shuffling is not needed
                         shuffle_batch=False, min_after_dequeue=None,
                         num_threads=num_threads,
                         allow_smaller_final_batch=allow_smaller_final_batch,
                         seed=seed)
        if classes is not None:
            if not isinstance(classes, list):
                raise ValueError("classes must be a list")
            try:
                classes = [int(c) for c in classes]
            except ValueError:
                raise ValueError('classes must be convertible to int')
            if not all(i >= 0 and i <= 9 for i in classes):
                raise ValueError("elements of classes must be digits in the "
                                 "interval [0, 9]")
            if len(set(classes)) != len(classes):
                raise ValueError('classes must contain unique elements')
        self._classes = classes
        self._samples = None
        self._labels = None
        # Get data dir
        self._data_dir = os.path.realpath(os.path.join(
            os.getcwd(), os.path.dirname(__file__),
            os.pardir, os.pardir, 'data', 'mnist'))

    @property
    def orig_height(self):
        """int: Height of the original images."""
        return self._orig_height

    @property
    def orig_width(self):
        """int: Width of the original images."""
        return self._orig_width

    @property
    def format(self):
        """Image format."""
        return self._format

    @property
    def ratio(self):
        """int: Original images are downsampled this number of times."""
        return self._ratio

    @property
    def crop(self):
        """int: That many border pixels are cropped."""
        return self._crop

    @property
    def classes(self):
        """list of int: List of classes provided by the dataset."""
        if self._classes is not None:
            return self._classes
        else:
            return list(range(10))

    @property
    def samples(self):
        """array: Array of all data samples."""
        return self._samples

    @property
    def labels(self):
        """array: Array of all data labels."""
        return self._labels

    @property
    def shape(self):
        """Shape of the image data samples."""
        return ImageShape(self._height, self._width, self._num_channels)

    @utils.docinherit(Dataset)
    def generate_data(self):
        self.load_data()
        # Add input producer that serves the loaded samples
        # All data is shuffled independently of the capacity parameter
        producer = tf.train.slice_input_producer([self._samples, self._labels],
                                                 num_epochs=self._num_epochs,
                                                 # Shuffle data
                                                 shuffle=self._shuffle,
                                                 seed=self._seed)
        return producer

    @utils.docinherit(Dataset)
    def process_data(self, data):
        # Everything is processed before entering the producer
        return data

    def load_data(self):
        """Load all data from MNIST data files."""
        # Load data
        if (self._subset == MNISTDataset.Subset.ALL or
                self._subset == MNISTDataset.Subset.TRAIN):
            self.__info("Loading MNIST training data")
            train_x = self._load_images('train-images-idx3-ubyte.gz')
            train_y = self._load_labels('train-labels-idx1-ubyte.gz')

        if (self._subset == MNISTDataset.Subset.ALL or
                self._subset == MNISTDataset.Subset.TEST):
            self.__info("Loading MNIST test data")
            test_x = self._load_images('t10k-images-idx3-ubyte.gz')
            test_y = self._load_labels('t10k-labels-idx1-ubyte.gz')

        # Collect
        if self._subset == MNISTDataset.Subset.TRAIN:
            samples = train_x
            labels = train_y
        elif self._subset == MNISTDataset.Subset.TEST:
            samples = test_x
            labels = test_y
        elif self._subset == MNISTDataset.Subset.ALL:
            samples = np.concatenate([train_x, test_x])
            labels = np.concatenate([train_y, test_y])

        # Filter classes
        if self._classes is None:
            self._labels = labels
        else:
            self.__debug1("Selecting classes %s" % self._classes)
            chosen = np.in1d(labels, self._classes)
            samples = samples[chosen]
            self._labels = labels[chosen]

        # Process data (input samples are HxW uint8, and well normalized (0-254/255))
        # - convert to float for accuracy, use float32, since that's what scipy wants
        samples = samples.astype(np.float32) / 255.0
        # - downsample
        if self._ratio > 1:
            num_samples = samples.shape[0]
            samples_resized = [None] * num_samples
            for i in range(num_samples):
                samples_resized[i] = scipy.misc.imresize(
                    samples[i], 1.0 / self._ratio,
                    # bicubic looks best after normalization, sharper than bilinear/lanczos
                    interp='bicubic',
                    mode='F')  # Operate on float32 images
            samples = np.array(samples_resized)
        # - crop (samples are float32 HxW)
        if self._crop > 0:
            samples = samples[:, self._crop:-self._crop, self._crop:-self._crop]
        # - flatten
        samples = samples.reshape(samples.shape[0], -1)
        # - normalize (resized image is likely not normalized)
        samples -= np.amin(samples, axis=1, keepdims=True)
        samples /= np.amax(samples, axis=1, keepdims=True)
        # - convert to format (samples are float32 [0, 1] flattened)
        if self._format == ImageFormat.FLOAT:
            # Already float [0,1]. Ensure the dtype is spn float dtype
            self._samples = samples.astype(conf.dtype.as_numpy_dtype())
        elif self._format == ImageFormat.INT:
            self._samples = np.rint(samples * 255.0).astype(np.uint8)
        elif self._format == ImageFormat.BINARY:
            self._samples = (samples > 0.5).astype(np.uint8)

    def _load_images(self, filename):
        """Extract MNIST images from a file. Stolen from TensorFlow.

        Args:
            filename (str): Filename of the labels file to load.

        Returns:
            array: A 3D uint8 numpy array [num, height, width].

        Raises:
            ValueError: If the bytestream does not start with 2051.
        """
        with gzip.GzipFile(os.path.join(self._data_dir, filename)) as bytestream:
            magic = _read32(bytestream)
            if magic != 2051:
                raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                                 (magic, filename))
            num_images = _read32(bytestream)
            rows = _read32(bytestream)
            cols = _read32(bytestream)
            buf = bytestream.read(rows * cols * num_images)
            data = np.frombuffer(buf, dtype=np.uint8)
            data = data.reshape(num_images, rows, cols)
            return data

    def _load_labels(self, filename):
        """Extract MNIST labels from a file. Stolen from TensorFlow.

        Args:
            filename (str): Filename of the labels file to load.

        Returns:
            array: a 2D int numpy array [num, 1].

        Raises:
            ValueError: If the bytestream does not start with 2049.
        """
        with gzip.GzipFile(os.path.join(self._data_dir, filename)) as bytestream:
            magic = _read32(bytestream)
            if magic != 2049:
                raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                                 (magic, filename))
            num_items = _read32(bytestream)
            buf = bytestream.read(num_items)
            labels = np.frombuffer(buf, dtype=np.uint8)
            return labels.reshape((-1, 1)).astype(np.int)
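
A usage sketch (not from the source), assuming MNISTDataset and ImageFormat are exposed at the libspn package level and that the Dataset base class offers a get_data() method assembling the batched tensors:

import libspn as spn

dataset = spn.MNISTDataset(subset=spn.MNISTDataset.Subset.TRAIN,
                           format=spn.ImageFormat.BINARY,
                           num_epochs=1, batch_size=100, shuffle=True,
                           ratio=2, crop=2, classes=[0, 1])
samples, labels = dataset.get_data()  # get_data() is an assumption about the Dataset base class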
Example #6
class ImageDataWriter(DataWriter):
    """
    A writer that writes flattened image data. The image format is selected by the
    extension given as part of the path. The images are normalized to [0, 1]
    before being saved.

    Args:
        path (str): Path to an image file. The path can be parameterized by
                    ``%n``, which will be replaced by the number of the image.
                    It can also be parameterized by ``%l`` which will be replaced
                    by any labels given to :func:`write`.
        shape (ImageShape): Shape of the image data.
        normalize (bool): Normalize data before saving.
        num_digits (int): Minimum number of digits of the image number.
    """

    __logger = get_logger()
    __debug1 = __logger.debug1
    __is_debug1 = __logger.is_debug1

    def __init__(self, path, shape, normalize=False, num_digits=4):
        self._path = os.path.expanduser(path)
        self._shape = shape
        self._normalize = normalize
        self._num_digits = num_digits
        self._image_no = 0

    def write(self, images, labels=None, image_no=None):
        """
        Write image data to file(s). ``images`` can be either a single image
        (a 1D array) or a batch of images (a 2D array).

        Args:
            images (array): A 1D array containing a single image or a 2D array
                            containing multiple images of the given shape.
            labels (value or array): For a single image, a value or a
                single-element 1D or 2D array containing a label for the image.
                For multiple images, a 1D or 2D array containing labels for the
                images.
            image_no (int): Optional. Number of the first image to write.
                            If not given, the number of the last written image
                            is incremented and used.
        """
        if not isinstance(images, np.ndarray):
            raise ValueError("images must be an array")
        if not (np.issubdtype(images.dtype, np.integer)
                or np.issubdtype(images.dtype, np.floating)):
            raise ValueError("images must be of int or float dtype")
        if image_no is not None and not isinstance(image_no, int):
            raise ValueError("image_no must be integer")

        if self.__is_debug1():
            self.__debug1(
                "Batch size:%s dtype:%s max_min:%s min_max:%s" %
                (images.shape[0], images.dtype, np.amax(np.amin(
                    images, axis=1)), np.amin(np.amax(images, axis=1))))

        # Convert 1-image case to multi-image case
        if images.ndim == 1:
            images = images.reshape([1, -1])
            if labels is not None:
                if isinstance(labels, np.ndarray):
                    if labels.ndim == 1:
                        labels = labels.reshape([1, -1])
                    elif labels.ndim != 2:
                        raise ValueError(
                            "labels array must be 1 or 2 dimensional")
                else:
                    labels = np.array([[labels]])
        elif images.ndim == 2:
            if labels is not None:
                if isinstance(labels, np.ndarray):
                    if labels.ndim == 1:
                        labels = labels.reshape([-1, 1])  # one label per image
                    elif labels.ndim != 2:
                        raise ValueError(
                            "labels array must be 1 or 2 dimensional")
                else:
                    raise ValueError("labels must be an array")
        else:
            raise ValueError("images array must be 1 or 2 dimensional")
        # Set image number
        if image_no is not None:
            self._image_no = image_no
        # Calculate shape for imsave
        shape = self._shape
        if self._shape[2] == 1:  # imsave wants single channel as MxN
            shape = self._shape[:2]
        # Save all images
        for i in range(images.shape[0]):
            # Generate path
            path = self._path
            if '%n' in path:
                path = path.replace('%n',
                                    ("%0" + str(self._num_digits) + "d") %
                                    (self._image_no))
            if labels is not None and '%l' in path:
                label = labels[i, 0]
                if isinstance(label, bytes):
                    label = label.decode("utf-8")
                else:
                    label = str(label)
                path = path.replace('%l', label)
            # Normalize?
            # imsave normalizes float images, but not uint8 images
            if self._normalize:
                if np.issubdtype(images.dtype, np.integer):
                    images = images.astype(np.float32)
            else:
                if np.issubdtype(images.dtype, np.floating):
                    images *= 255.0
                images = images.astype(np.uint8)  # Convert also int32/64 to 8
            # Save
            scipy.misc.imsave(path, images[i].reshape(shape))
            self._image_no += 1
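
A usage sketch (not from the source), assuming ImageDataWriter and ImageShape are exposed at the libspn package level:

import numpy as np
import libspn as spn

# %n is replaced by the image number, %l by the label passed to write().
writer = spn.ImageDataWriter("mnist_img_%n_%l.png",
                             shape=spn.ImageShape(28, 28, 1),
                             normalize=True)
images = np.random.randint(0, 256, size=(2, 28 * 28)).astype(np.uint8)
labels = np.array([[3], [7]])
writer.write(images, labels=labels)

Example #7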
class ConvProductsDepthwise(ConvProducts):
    """A container representing convolutional products in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this container.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_channels (int): Number of channels modeled by this node. This parameter is optional.
            If ``None``, the layer will attempt to generate all possible permutations of channels
            under a patch as long as it is under ``num_channels_max``.
        padding (str): Type of padding used. Can be either 'full', 'valid' or 'wicker_top'.
            For building Wicker CSPNs, 'full' padding is necessary in all but the very last
            ConvProducts node. The last ConvProducts node should use the 'wicker_top' padding.
        dilation_rate (int or tuple of ints): Dilation rate of the convolution.
        strides (int or tuple of ints): Strides used for the convolution.
        spatial_dim_sizes (list or tuple of ints): Dim sizes of spatial dimensions (height and width)
        num_channels_max (int): The maximum number of channels when automatically generating
            permutations.
        name (str): Name of the container.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this container that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """

    logger = get_logger()

    def __init__(self, *values, padding='valid', dilation_rate=1,
                 strides=2, kernel_size=2, inference_type=InferenceType.MARGINAL,
                 name="ConvProductsDepthwise", spatial_dim_sizes=None):
        super().__init__(
            *values, inference_type=inference_type, name=name, spatial_dim_sizes=spatial_dim_sizes,
            strides=strides, kernel_size=kernel_size, padding=padding, dilation_rate=dilation_rate)
        self._num_channels = self._num_input_channels()

    @utils.lru_cache
    def _compute_log_value(self, *input_tensors):
        # Concatenate along channel axis
        concat_inp = self._prepare_convolutional_processing(*input_tensors)

        # This is the quickest workaround for TensorFlow's apparent optimization whenever
        # part of the kernel computation involves a -inf:
        concat_inp = tf.where(
            tf.is_inf(concat_inp), tf.fill(tf.shape(concat_inp), value=-1e20), concat_inp)
        # Convolve
        conv_out = tf.nn.conv2d(
            input=self._channels_to_batch(concat_inp),
            filter=tf.ones(self._kernel_size + [1, 1]),
            padding='VALID',
            strides=[1] + self._strides + [1],
            dilations=[1] + self._dilation_rate + [1],
            data_format='NHWC'
        )
        conv_out = self._batch_to_channels(conv_out)
        return self._flatten(conv_out)

    @utils.lru_cache
    def _channels_to_batch(self, t):
        gd = t.shape.as_list()[1:3]
        return tf.reshape(self._transpose_channel_last_to_first(t), [-1] + gd + [1])

    @utils.lru_cache
    def _batch_to_channels(self, t):
        gd = t.shape.as_list()[1:3]
        return self._transpose_channel_first_to_last(tf.reshape(t, [-1, self._num_channels] + gd))

    def _compute_mpe_path_common(self, counts, *input_values):
        if not self._values:
            raise StructureError("{} is missing input values.".format(self))
        # Concatenate inputs along channel axis, should already be done during forward pass
        inp_concat = self._prepare_convolutional_processing(*input_values)
        spatial_counts = tf.reshape(counts, (-1,) + self.output_shape_spatial)

        inp_concat = self._channels_to_batch(inp_concat)
        spatial_counts = self._channels_to_batch(spatial_counts)

        input_counts = tf.nn.conv2d_backprop_input(
            input_sizes=tf.shape(inp_concat),
            filter=tf.ones(self._kernel_size + [1, 1]),
            out_backprop=spatial_counts,
            strides=[1] + self._strides + [1],
            padding='VALID',
            dilations=[1] + self._dilation_rate + [1],
            data_format="NHWC")

        input_counts = self._batch_to_channels(input_counts)

        # In case we have explicitly padded the tensor before forward convolution, we should
        # slice the counts now
        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        if not any([pad_left, pad_right, pad_top, pad_bottom]):
            return self._split_to_children(input_counts)
        return self._split_to_children(input_counts[:, pad_top:-pad_bottom, pad_left:-pad_right, :])
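
The depthwise convolution above uses an all-ones kernel to sum log-values inside each kernel window, which corresponds to multiplying the underlying probabilities. A small NumPy illustration of that identity (not library code):

import numpy as np

probs = np.array([0.5, 0.25, 0.1, 0.8])
log_probs = np.log(probs)

# Summing log-probabilities over a window (what the ones-kernel convolution does)
# equals the log of the product of the probabilities in that window.
assert np.isclose(log_probs.sum(), np.log(probs.prod()))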
Example #8
class DiscreteDenseModel(Model):
    """Basic dense SPN model operating on discrete data.

    If `num_classes` is greater than 1, a multi-class model is created by
    generating multiple parallel dense models (one for each class) and combining
    them with a sum node with an explicit latent class variable.

    Args:
        num_vars (int): Number of discrete random variables representing data
                        samples.
        num_vals (int): Number of values of each random variable.
        num_classes (int): Number of classes assumed by the model.
        num_decomps (int): Number of decompositions at each level of dense SPN.
        num_subsets (int): Number of variable sub-sets for each decomposition.
        num_mixtures (int): Number of mixtures (sums) for each variable subset.
        input_dist (InputDist): Determines how IVs of the discrete variables for
                                data samples are connected to the model.
        num_input_mixtures (int): Number of mixtures used representing each
                                  discrete data variable (mixing the data variable
                                  IVs) when ``input_dist`` is set to ``MIXTURE``.
                                  If set to ``None``, ``num_mixtures`` is used.
        weight_initializer: Initializer used for the weights.
    """

    __logger = get_logger()
    __info = __logger.info
    __debug1 = __logger.debug1
    __is_debug1 = __logger.is_debug1
    __debug2 = __logger.debug2
    __is_debug2 = __logger.is_debug2

    def __init__(self,
                 num_classes,
                 num_decomps,
                 num_subsets,
                 num_mixtures,
                 input_dist=DenseSPNGenerator.InputDist.MIXTURE,
                 num_input_mixtures=None,
                 weight_initializer=tf.initializers.random_uniform(0.0, 1.0)):
        super().__init__()
        if not isinstance(num_classes, int):
            raise ValueError("num_classes must be an integer")
        self._num_classes = num_classes
        self._num_decomps = num_decomps
        self._num_subsets = num_subsets
        self._num_mixtures = num_mixtures
        self._input_dist = input_dist
        self._num_input_mixtures = num_input_mixtures
        self._weight_initializer = weight_initializer
        self._class_ivs = None
        self._sample_ivs = None
        self._class_input = None
        self._sample_inputs = None

    def __repr__(self):
        return (type(self).__qualname__ + "(" +
                ("num_classes=" + str(self._num_classes)) + ", " +
                ("num_decomps=" + str(self._num_decomps)) + ", " +
                ("num_subsets=" + str(self._num_subsets)) + ", " +
                ("num_mixtures=" + str(self._num_mixtures)) + ", " +
                ("input_dist=" + str(self._input_dist)) + ", " +
                ("num_input_mixtures=" + str(self._num_input_mixtures)) +
                ", " + ("weight_init_value=" + str(self._weight_initializer)) +
                ")")

    @utils.docinherit(Model)
    def serialize(self, save_param_vals=True, sess=None):
        # Serialize the graph first
        data = serialize_graph(self._root,
                               save_param_vals=save_param_vals,
                               sess=sess)
        # Add model specific information
        # Inputs
        if self._sample_ivs is not None:
            data['sample_ivs'] = self._sample_ivs.name
        data['sample_inputs'] = [(i.node.name, i.indices)
                                 for i in self._sample_inputs]
        if self._class_ivs is not None:
            data['class_ivs'] = self._class_ivs.name
        if self._class_input:
            data['class_input'] = (self._class_input.node.name,
                                   self._class_input.indices)
        # Model params
        data['num_classes'] = self._num_classes
        data['num_decomps'] = self._num_decomps
        data['num_subsets'] = self._num_subsets
        data['num_mixtures'] = self._num_mixtures
        data['input_dist'] = self._input_dist
        data['num_input_mixtures'] = self._num_input_mixtures
        data['weight_init_value'] = self._weight_initializer
        return data

    @utils.docinherit(Model)
    def deserialize(self, data, load_param_vals=True, sess=None):
        # Deserialize the graph first
        nodes_by_name = {}
        self._root = deserialize_graph(data,
                                       load_param_vals=load_param_vals,
                                       sess=sess,
                                       nodes_by_name=nodes_by_name)
        # Model specific information
        # Inputs
        sample_ivs = data.get('sample_ivs', None)
        if sample_ivs:
            self._sample_ivs = nodes_by_name[sample_ivs]
        else:
            self._sample_ivs = None
        self._sample_inputs = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['sample_inputs'])
        class_ivs = data.get('class_ivs', None)
        if class_ivs:
            self._class_ivs = nodes_by_name[class_ivs]
        else:
            self._class_ivs = None
        class_input = data.get('class_input', None)
        if class_input:
            self._class_input = Input(nodes_by_name[class_input[0]],
                                      class_input[1])
        else:
            self._class_input = None
        # Model params
        self._num_classes = data['num_classes']
        self._num_decomps = data['num_decomps']
        self._num_subsets = data['num_subsets']
        self._num_mixtures = data['num_mixtures']
        self._input_dist = data['input_dist']
        self._num_input_mixtures = data['num_input_mixtures']
        self._weight_initializer = data['weight_init_value']

    @property
    def sample_ivs(self):
        """IVs: IVs with input data sample."""
        return self._sample_ivs

    @property
    def class_ivs(self):
        """IVs: Class indicator variables."""
        return self._class_ivs

    @property
    def sample_inputs(self):
        """list of Input: Inputs to the model providing data samples."""
        return self._sample_inputs

    @property
    def class_input(self):
        """Input: Input providing class indicators.."""
        return self._class_input

    def build(self,
              *sample_inputs,
              class_input=None,
              num_vars=None,
              num_vals=None,
              seed=None):
        """Build the SPN graph of the model.

        The model can be built on top of any ``sample_inputs``. If no sample
        inputs are provided, the model will internally create a single IVs
        node to represent the input data samples. In that case, ``num_vars`` and
        ``num_vals`` must be specified.

        Similarly, if ``class_input`` is provided, it is used as a source of
        class indicators of the root sum node combining sub-SPNs modeling
        particular classes. Otherwise, an internal IVs node is created for this
        purpose.

        Args:
            *sample_inputs (input_like): Optional. Inputs to the model
                                         providing data samples.
            class_input (input_like): Optional. Input providing class indicators.
            num_vars (int): Optional. Number of variables in each sample. Must
                            only be provided if ``sample_inputs`` are not given.
            num_vals (int or list of int): Optional. Number of values of each
                variable. Can be a single value or a list of values, one for
                each of ``num_vars`` variables. Must only be provided if
                ``sample_inputs`` are not given.
            seed (int): Optional. Seed used for the dense SPN generator.

        Returns:
           Sum: Root node of the generated model.
        """
        if not sample_inputs:
            if num_vars is None:
                raise ValueError(
                    "num_vars must be given when sample_inputs are not")
            if num_vals is None:
                raise ValueError(
                    "num_vals must be given when sample_inputs are not")
            if not isinstance(num_vars, int) or num_vars < 1:
                raise ValueError("num_vars must be an integer > 0")
            if not isinstance(num_vals, int) or num_vals < 1:
                raise ValueError("num_vals must be an integer > 0")
        if self._num_classes > 1:
            self.__info("Building a discrete dense model with %d classes" %
                        self._num_classes)
        else:
            self.__info("Building a 1-class discrete dense model")

        # Create IVs if inputs not given
        if not sample_inputs:
            self._sample_ivs = IVs(num_vars=num_vars,
                                   num_vals=num_vals,
                                   name="SampleIVs")
            self._sample_inputs = [Input(self._sample_ivs)]
        else:
            self._sample_inputs = tuple(
                Input.as_input(i) for i in sample_inputs)
        if self._num_classes > 1:
            if class_input is None:
                self._class_ivs = IVs(num_vars=1,
                                      num_vals=self._num_classes,
                                      name="ClassIVs")
                self._class_input = Input(self._class_ivs)
            else:
                self._class_input = Input.as_input(class_input)

        # Generate structure
        dense_gen = DenseSPNGenerator(
            num_decomps=self._num_decomps,
            num_subsets=self._num_subsets,
            num_mixtures=self._num_mixtures,
            input_dist=self._input_dist,
            num_input_mixtures=self._num_input_mixtures,
            balanced=True)
        rnd = random.Random(seed)
        if self._num_classes == 1:
            # One-class
            self._root = dense_gen.generate(*self._sample_inputs,
                                            rnd=rnd,
                                            root_name='Root')
        else:
            # Multi-class: create sub-SPNs
            sub_spns = []
            for c in range(self._num_classes):
                rnd_copy = random.Random()
                rnd_copy.setstate(rnd.getstate())
                with tf.name_scope("Class%d" % c):
                    sub_root = dense_gen.generate(*self._sample_inputs,
                                                  rnd=rnd_copy)
                if self.__is_debug1():
                    self.__debug1("sub-SPN %d has %d nodes" %
                                  (c, sub_root.get_num_nodes()))
                sub_spns.append(sub_root)
            # Create root
            self._root = Sum(*sub_spns, ivs=self._class_input, name="Root")

        if self.__is_debug1():
            self.__debug1("SPN graph has %d nodes" %
                          self._root.get_num_nodes())

        # Generate weight nodes
        self.__debug1("Generating weight nodes")
        generate_weights(self._root, initializer=self._weight_initializer)
        if self.__is_debug1():
            self.__debug1(
                "SPN graph has %d nodes and %d TF ops" %
                (self._root.get_num_nodes(), self._root.get_tf_graph_size()))

        return self._root
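
A usage sketch (not from the source), assuming DiscreteDenseModel is exposed at the libspn package level:

import libspn as spn

# No sample or class inputs are passed, so build() creates internal
# SampleIVs/ClassIVs nodes; it also generates the weight nodes.
model = spn.DiscreteDenseModel(num_classes=10, num_decomps=1,
                               num_subsets=3, num_mixtures=2)
root = model.build(num_vars=16, num_vals=2, seed=100)
print(root.get_num_nodes())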
Example #9
class ProductsLayer(OpNode):
    """A node representing all products in a layer in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_or_size_prods (int or list of int):
            Int: Number of product ops modelled by this node, in which case all
            the products modelled will have a common size.
            List: Size of each product op modelled by this node. The number of
            products modelled equals the length of the list.
        name (str): Name of the node.
    """

    logger = get_logger()
    info = logger.info

    def __init__(self, *values, num_or_size_prods=1, name="ProductsLayer"):
        self._values = []
        super().__init__(inference_type=InferenceType.MARGINAL, name=name)
        self.set_values(*values)
        self.set_prod_sizes(num_or_size_prods)

    def set_prod_sizes(self, num_or_size_prods):
        # Total size of value input_size
        total_values_size = sum(
            len(v.indices) if v and v.indices else v.node.get_out_size() if v else 0
            for v in self._values)

        if isinstance(num_or_size_prods, int):  # Total number of product ops to be modelled
            if not num_or_size_prods > 0:
                raise StructureError("In %s 'num_or_size_prods': %s needs to be > 0"
                                     % (self, num_or_size_prods))
            self._num_prods = num_or_size_prods
            self._prod_input_sizes = [total_values_size // self._num_prods] * self._num_prods
        elif isinstance(num_or_size_prods, list):  # Size of each modelled product op
            if not len(num_or_size_prods) > 0:
                raise StructureError("In %s 'num_or_size_prods': %s cannot be an empty list"
                                     % self, num_or_size_prods)
            self._prod_input_sizes = num_or_size_prods
            self._num_prods = len(num_or_size_prods)

        self._num_or_size_prods = num_or_size_prods

    def serialize(self):
        data = super().serialize()
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        data['num_prods'] = self._num_prods
        data['prod_input_sizes'] = self._prod_input_sizes
        data['num_or_size_prods'] = self._num_or_size_prods
        return data

    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()
        self._num_prods = data['num_prods']
        self._prod_input_sizes = data['prod_input_sizes']
        self._num_or_size_prods = data['num_or_size_prods']

    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(Input(nodes_by_name[nn], i)
                             for nn, i in data['values'])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return self._values

    @property
    def num_prods(self):
        """int: Number of Product ops modelled by this node."""
        return self._num_prods

    def set_num_prods(self, num_prods=1):
        """Set the number of Product ops modelled by this node.

        Args:
            num_prods (int): Number of Product ops modelled by this node.
        """
        self._num_prods = num_prods

    @property
    def num_or_size_prods(self):
        """int: Number of Product ops modelled by this node."""
        return self._num_or_size_prods

    def set_num_or_size_prods(self, num_or_size_prods=1):
        """Set the number of Product ops modelled by this node.

        Args:
            num_prods (int): Number of Product ops modelled by this node.
        """
        self._num_or_size_prods = num_or_size_prods

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def add_values(self, *values):
        """Add more inputs providing input values to this node.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._values + self._parse_inputs(*values)

    @property
    def _const_out_size(self):
        return True

    @utils.lru_cache
    def _compute_out_size(self, *input_out_sizes):
        return self._num_prods

    def _compute_scope(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        # Gather and flatten value scopes
        flat_value_scopes = list(chain.from_iterable(self._gather_input_scopes(
                                                *value_scopes)))
        # Divide gathered and flattened value scopes into sublists, one per
        # modeled product op.
        prod_input_sizes = np.cumsum(np.array(self._prod_input_sizes)).tolist()
        prod_input_sizes.insert(0, 0)
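        # E.g. (illustrative values) _prod_input_sizes = [2, 3, 2] gives
        # boundaries [0, 2, 5, 7], so the flat scopes are sliced into [0:2],
        # [2:5] and [5:7], one slice per modelled product.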
        value_scopes_lists = [flat_value_scopes[start:stop] for start, stop in
                              zip(prod_input_sizes[:-1], prod_input_sizes[1:])]
        return [Scope.merge_scopes(vsl) for vsl in value_scopes_lists]

    def _compute_valid(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        value_scopes_ = self._gather_input_scopes(*value_scopes)
        # If already invalid, return None
        if any(s is None for s in value_scopes_):
            return None
        # Check product decomposability
        flat_value_scopes = list(chain.from_iterable(value_scopes_))
        # Divide gathered and flattened value scopes into sublists, one per
        # modeled product op.
        prod_input_sizes = np.cumsum(np.array(self._prod_input_sizes)).tolist()
        prod_input_sizes.insert(0, 0)
        value_scopes_lists = [flat_value_scopes[start:stop] for start, stop in
                              zip(prod_input_sizes[:-1], prod_input_sizes[1:])]
        for scopes in value_scopes_lists:
            for s1, s2 in combinations(scopes, 2):
                if s1 & s2:
                    self.__info("%s is not decomposable", self)
                    return None
        return self._compute_scope(*value_scopes)

    def _combine_values_and_indices(self, value_tensors):
        """
        Concatenates input tensors and returns the nested indices that are
        required for gathering all product inputs to a reducible set of columns.
        """
        # Use lists instead of a dict to maintain insertion order
        unique_tensors = []
        unique_offsets = []

        combined_indices = []
        flat_col_indices = []
        flat_tensor_offsets = []

        tensor_offset = 0
        for value_inp, value_tensor in zip(self._values, value_tensors):
            # Get indices. If not given, default to [0, 1, ..., len-1]
            indices = value_inp.indices if value_inp.indices else \
                np.arange(value_inp.node.get_out_size()).tolist()
            flat_col_indices.append(indices)
            if value_tensor not in unique_tensors:
                # Add the tensor and its offset to the unique lists
                unique_tensors.append(value_tensor)
                unique_offsets.append(tensor_offset)
                # Add offsets
                flat_tensor_offsets.append([tensor_offset for _ in indices])
                tensor_offset += value_tensor.shape[1].value
            else:
                # Find offset from list
                offset = unique_offsets[unique_tensors.index(value_tensor)]
                # No need to update tensor_offset here, since the current
                # value_tensor was not added to the unique list
                flat_tensor_offsets.append([offset for _ in indices])

        # Flatten the tensor offsets and column indices
        flat_tensor_offsets = np.asarray(list(chain(*flat_tensor_offsets)))
        flat_col_indices = np.asarray(list(chain(*flat_col_indices)))

        # Offset in flattened arrays
        offset = 0
        for size in self._prod_input_sizes:
            # Now indices can be found by adding up column indices and tensor offsets
            indices = flat_col_indices[offset:offset + size] + \
                      flat_tensor_offsets[offset:offset + size]

            # Combined indices contains an array for each reducible set of columns
            combined_indices.append(indices)
            offset += size

        return combined_indices, tf.concat(unique_tensors, 1)

    @utils.lru_cache
    def _compute_value_common(self, *value_tensors, padding_value=0.0):
        """Common actions when computing value."""
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        # Prepare values
        if self._num_prods > 1:
            indices, value_tensor = self._combine_values_and_indices(value_tensors)
            # Create a 3D tensor with dimensions [batch, num-prods, max-prod-input-sizes]
            # The last axis will have zeros or ones (for log or non-log) when the
            # prod-input-size < max-prod-input-sizes
            reducible_values = utils.gather_cols_3d(value_tensor, indices,
                                                    pad_elem=padding_value)
            return reducible_values
        else:
            # Gather input tensors
            value_tensors = self._gather_input_tensors(*value_tensors)
            return tf.concat(value_tensors, 1)

    @utils.lru_cache
    def _compute_log_value(self, *value_tensors):
        values = self._compute_value_common(*value_tensors, padding_value=0.0)

        # Wrap the log value with its custom gradient
        @tf.custom_gradient
        def log_value(*unique_tensors):
            # Defines gradient for the log value
            def gradient(gradients):
                scattered_grads = self._compute_log_mpe_path(gradients, *value_tensors)
                return [sg for sg in scattered_grads if sg is not None]
            return tf.reduce_sum(values, axis=-1, keepdims=(False if self._num_prods > 1
                                                             else True)), gradient

        unique_tensors = self._get_differentiable_inputs(*value_tensors)
        if conf.custom_gradient:
            return log_value(*unique_tensors)
        else:
            return tf.reduce_sum(
                values, axis=-1, keep_dims=(False if self._num_prods > 1 else True))

    @utils.lru_cache
    def _get_differentiable_inputs(self, *value_tensors):
        unique_tensors = list(OrderedDict.fromkeys(value_tensors))
        return unique_tensors

    @utils.lru_cache
    def _compute_log_mpe_value(self, *value_tensors):
        return self._compute_log_value(*value_tensors)

    def _collect_count_indices_per_input(self):
        """
        For every unique (input, index) pair in the node's values list, collects
        and returns all column-indices of the counts tensor, for which the unique
        pair is a child of.
        """
        # Create a list of each input, paired with all the indices assosiated
        # with it
        # Eg: self._values = [(A, [0, 2, 3]),
        #                     (B, 1),
        #                     (A, None),
        #                     (B, [1, 2])]
        # expanded_inputs_list = [(A, 0), (A, 2), (A, 3),
        #                         (B, 1),
        #                         (A, 0), (A, 1), (A, 2), (A, 3),
        #                         (B, 1), (B, 2)]
        expanded_inputs_list = []
        for inp in self._values:
            if inp.indices is None:
                for i in range(inp.node.get_out_size()):
                    expanded_inputs_list.append((inp.node, i))
            elif isinstance(inp.indices, list):
                for i in inp.indices:
                    expanded_inputs_list.append((inp.node, i))
            elif isinstance(inp.indices, int):
                expanded_inputs_list.append((inp.node, inp.indices))

        # Create a list grouping together all inputs to each product modelled
        # Eg: self._prod_input_sizes = [2, 3, 2, 1, 2]
        #     prod_inputs_lists = [[(A, 0), (A, 2)],        # Prod-0
        #                          [(A, 3), (B, 1),(A, 0)], # Prod-1
        #                          [(A, 1), (A, 2)],        # Prod-2
        #                          [(A, 3)],                # Prod-3
        #                          [(B, 1), (B, 2)]]        # Prod-4
        prod_input_sizes = np.cumsum(np.array(self._prod_input_sizes)).tolist()
        prod_input_sizes.insert(0, 0)
        prod_inputs_lists = [expanded_inputs_list[start:stop] for start, stop in
                             zip(prod_input_sizes[:-1], prod_input_sizes[1:])]

        # Create a dictionary with each unique input and index pair as its key,
        # and a list of product-indices as the corresponding value
        # Eg: unique_inps_inds_dict = {(A, 0): [0, 1], # Prod-0 and  Prod-1
        #                              (A, 1): [2],    # Prod-2
        #                              (A, 2): [0, 2], # Prod-0 and  Prod-2
        #                              (A, 3): [1],    # Prod-1
        #                              (B, 1): [1, 4], # Prod-1 and  Prod-4
        #                              (B, 2): [4]}    # Prod-4
        unique_inps_inds = defaultdict(list)
        for idx, inps in enumerate(prod_inputs_lists):
            for inp in inps:
                unique_inps_inds[inp] += [idx]

        # Sort the dictionary by key - sorting avoids a scatter op when the
        # original input is passed without indices
        unique_inps_inds = OrderedDict(sorted(unique_inps_inds.items()))

        # Collect all product indices as a nested list of indices to gather from
        # counts tensor
        # Eg: gather_counts_indices = [[0, 1],
        #                              [2],
        #                              [0, 2],
        #                              [1],
        #                              [1, 4],
        #                              [4]]
        gather_counts_indices = list(unique_inps_inds.values())

        # Create an ordered dictionary of unique inputs to this node as key,
        # and a list of unique indices per input as the corresponding value
        # Eg: unique_inps = {A: [0, 1, 2, 3]
        #                    B: [1, 2]}
        unique_inps = OrderedDict()
        for inp, ind in unique_inps_inds.keys():
            unique_inps[inp] = []
        for inp, ind in unique_inps_inds.keys():
            unique_inps[inp] += [ind]

        return gather_counts_indices, unique_inps

    @utils.lru_cache
    def _compute_log_mpe_path(self, counts, *value_values,
                              use_unweighted=False, sample=False, sample_prob=None):
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)

        # For each unique (input, index) pair in the values list, collect the
        # indices of all counts for which the pair is a child
        gather_counts_indices, unique_inputs = self._collect_count_indices_per_input()

        if self._num_prods > 1:
            # Gather columns from the counts tensor, per unique (input, index) pair
            reducible_values = utils.gather_cols_3d(counts, gather_counts_indices)

            # Sum gathered counts together per unique (input, index) pair
            summed_counts = tf.reduce_sum(reducible_values, axis=-1)
        else:
            # Calculate total inputs size
            inputs_size = sum([v_input.get_size(v_value) for v_input, v_value in
                               zip(self._values, value_values)])

            # Tile counts only if input is larger than 1
            summed_counts = (tf.tile(counts, [1, inputs_size]) if inputs_size > 1
                             else counts)

        # For each unique input in the values list, calculate the number of
        # unique indices
        unique_inp_sizes = [len(v) for v in unique_inputs.values()]

        # Split the summed-counts tensor per unique input, based on input-sizes
        unique_input_counts = tf.split(summed_counts, unique_inp_sizes, axis=-1) \
            if len(unique_inp_sizes) > 1 else [summed_counts]

        # Scatter each unique-counts tensor to the respective input, only once
        # per unique input in the values list
        scattered_counts = [None] * len(self._values)
        for (node, inds), cnts in zip(unique_inputs.items(), unique_input_counts):
            for i, (inp, val) in enumerate(zip(self._values, value_values)):
                if inp.node == node:
                    scattered_counts[i] = utils.scatter_cols(
                        cnts, inds, int(val.get_shape()[0 if val.get_shape().ndims
                                                        == 1 else 1]))
                    break

        return scattered_counts

    def _compute_log_gradient(self, gradients, *value_values):
        return self._compute_log_mpe_path(gradients, *value_values)

    def disconnect_inputs(self):
        self._values = None

    @property
    def is_layer(self):
        return True
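
# A minimal NumPy sketch (not part of libspn) of what the layer computes in the
# log domain: each modelled product sums the log-values of its own group of
# input columns, and shorter groups are padded with zeros, which are neutral
# under log-domain summation (log(1) = 0). All names below are illustrative.
import numpy as np

def products_layer_log_value(log_inputs, prod_input_sizes):
    """Return per-product log-values for a batch of flattened log-inputs."""
    batch_size = log_inputs.shape[0]
    max_size = max(prod_input_sizes)
    boundaries = np.cumsum([0] + list(prod_input_sizes))
    # Pad each product's slice of columns to max_size with zeros
    padded = np.zeros((batch_size, len(prod_input_sizes), max_size))
    for k, (start, stop) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        padded[:, k, :stop - start] = log_inputs[:, start:stop]
    # Summing log-values along the last axis multiplies the child values
    return padded.sum(axis=-1)

# One sample, 5 input columns grouped into products of sizes [2, 3]
log_in = np.log(np.array([[0.5, 0.4, 0.9, 0.8, 0.7]]))
print(products_layer_log_value(log_in, [2, 3]))  # log([0.2, 0.504])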
Beispiel #10
0
class Sum(OpNode):
    """A node representing a single sum in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values.
        weights (input_like): Input providing weights node to this sum node.
            See :meth:`~libspn.Input.as_input` for possible values. If set
            to ``None``, the input is disconnected.
        ivs (input_like): Input providing IVs of an explicit latent variable
            associated with this sum node. See :meth:`~libspn.Input.as_input`
            for possible values. If set to ``None``, the input is disconnected.
        name (str): Name of the node.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this node that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """

    __logger = get_logger()
    __info = __logger.info

    def __init__(self, *values, weights=None, ivs=None,
                 inference_type=InferenceType.MARGINAL, name="Sum"):
        super().__init__(inference_type, name)
        self.set_values(*values)
        self.set_weights(weights)
        self.set_ivs(ivs)

    def serialize(self):
        data = super().serialize()
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        if self._weights:
            data['weights'] = (self._weights.node.name, self._weights.indices)
        if self._ivs:
            data['ivs'] = (self._ivs.node.name, self._ivs.indices)
        return data

    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()
        self.set_weights()
        self.set_ivs()

    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(Input(nodes_by_name[nn], i)
                             for nn, i in data['values'])
        weights = data.get('weights', None)
        if weights:
            self._weights = Input(nodes_by_name[weights[0]], weights[1])
        ivs = data.get('ivs', None)
        if ivs:
            self._ivs = Input(nodes_by_name[ivs[0]], ivs[1])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return (self._weights, self._ivs) + self._values

    @property
    def weights(self):
        """Input: Weights input."""
        return self._weights

    def set_weights(self, weights=None):
        """Set the weights input.

        Args:
            weights (input_like): Input providing weights node to this sum node.
                See :meth:`~libspn.Input.as_input` for possible values. If set
                to ``None``, the input is disconnected.
        """
        weights, = self._parse_inputs(weights)
        if weights and not isinstance(weights.node, Weights):
            raise StructureError("%s is not Weights" % weights.node)
        self._weights = weights

    @property
    def ivs(self):
        """Input: IVs input."""
        return self._ivs

    def set_ivs(self, ivs=None):
        """Set the IVs input.

        ivs (input_like): Input providing IVs of an explicit latent variable
            associated with this sum node. See :meth:`~libspn.Input.as_input`
            for possible values. If set to ``None``, the input is disconnected.
        """
        self._ivs, = self._parse_inputs(ivs)

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def add_values(self, *values):
        """Add more inputs providing input values to this node.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._values + self._parse_inputs(*values)

    def generate_weights(self, init_value=1, trainable=True,
                         input_sizes=None, name=None):
        """Generate a weights node matching this sum node and connect it to
        this sum.

        The function calculates the number of weights based on the number
        of input values of this sum. Therefore, weights should be generated
        once all inputs are added to this node.

        Args:
            init_value: Initial value of the weights. For possible values, see
                :meth:`~libspn.utils.broadcast_value`.
            trainable (bool): See :class:`~libspn.Weights`.
            input_sizes (list of int): Pre-computed sizes of each input of
                this node.  If given, this function will not traverse the graph
                to discover the sizes.
            name (str): Name of the weights node. If ``None`` use the name of the
                        sum + ``_Weights``.

        Return:
            Weights: Generated weights node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_Weights"
        # Count all input values
        if not input_sizes:
            input_sizes = self.get_input_sizes()
        num_values = sum(input_sizes[2:])  # Skip ivs, weights
        # Generate weights
        weights = Weights(init_value=init_value,
                          num_weights=num_values,
                          trainable=trainable, name=name)
        self.set_weights(weights)
        return weights

    def generate_ivs(self, feed=None, name=None):
        """Generate an IVs node matching this sum node and connect it to
        this sum.

        IVs should be generated once all inputs are added to this node,
        otherwise the number of IVs will be incorrect.

        Args:
            feed (Tensor): See :class:`~libspn.IVs`.
            name (str): Name of the IVs node. If ``None`` use the name of the
                        sum + ``_IVs``.

        Return:
            IVs: Generated IVs node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_IVs"
        # Count all input values
        num_values = sum(len(v.indices) if v.indices is not None
                         else v.node.get_out_size()
                         for v in self._values)
        ivs = IVs(feed=feed, num_vars=1, num_vals=num_values, name=name)
        self.set_ivs(ivs)
        return ivs

    @property
    def _const_out_size(self):
        return True

    def _compute_out_size(self, *input_out_sizes):
        return 1

    def _compute_scope(self, weight_scopes, ivs_scopes, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        _, ivs_scopes, *value_scopes = self._gather_input_scopes(weight_scopes,
                                                                 ivs_scopes,
                                                                 *value_scopes)
        flat_value_scopes = list(chain.from_iterable(value_scopes))
        if self._ivs:
            flat_value_scopes.extend(ivs_scopes)
        return [Scope.merge_scopes(flat_value_scopes)]

    def _compute_valid(self, weight_scopes, ivs_scopes, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        _, ivs_scopes_, *value_scopes_ = self._gather_input_scopes(weight_scopes,
                                                                   ivs_scopes,
                                                                   *value_scopes)
        # If already invalid, return None
        if (any(s is None for s in value_scopes_)
                or (self._ivs and ivs_scopes_ is None)):
            return None
        flat_value_scopes = list(chain.from_iterable(value_scopes_))
        # IVs
        if self._ivs:
            # Verify number of IVs
            if len(ivs_scopes_) != len(flat_value_scopes):
                raise StructureError("Number of IVs (%s) and values (%s) does "
                                     "not match for %s"
                                     % (len(ivs_scopes_), len(flat_value_scopes),
                                        self))
            # Check if scope of all IVs is just one and the same variable
            if len(Scope.merge_scopes(ivs_scopes_)) > 1:
                return None
        # Check sum for completeness wrt values
        first_scope = flat_value_scopes[0]
        if any(s != first_scope for s in flat_value_scopes[1:]):
            self.__info("%s is not complete with input value scopes %s",
                        self, flat_value_scopes)
            return None
        return self._compute_scope(weight_scopes, ivs_scopes, *value_scopes)

    def _compute_value_common(self, weight_tensor, ivs_tensor, *value_tensors):
        """Common actions when computing value."""
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if not self._weights:
            raise StructureError("%s is missing weights" % self)
        # Prepare values
        weight_tensor, ivs_tensor, *value_tensors = self._gather_input_tensors(
            weight_tensor, ivs_tensor, *value_tensors)
        values = utils.concat_maybe(value_tensors, 1)
        return weight_tensor, ivs_tensor, values

    def _compute_value(self, weight_tensor, ivs_tensor, *value_tensors):
        weight_tensor, ivs_tensor, values = self._compute_value_common(
            weight_tensor, ivs_tensor, *value_tensors)
        values_selected = values * ivs_tensor if self._ivs else values
        return tf.matmul(values_selected, tf.reshape(weight_tensor, [-1, 1]))

    def _compute_log_value(self, weight_tensor, ivs_tensor, *value_tensors):
        weight_tensor, ivs_tensor, values = self._compute_value_common(
            weight_tensor, ivs_tensor, *value_tensors)
        values_selected = values + ivs_tensor if self._ivs else values
        values_weighted = values_selected + weight_tensor
        return utils.reduce_log_sum(values_weighted)

    def _compute_mpe_value(self, weight_tensor, ivs_tensor, *value_tensors):
        weight_tensor, ivs_tensor, values = self._compute_value_common(
            weight_tensor, ivs_tensor, *value_tensors)
        values_selected = values * ivs_tensor if self._ivs else values
        values_weighted = values_selected * weight_tensor
        return tf.reduce_max(values_weighted, 1, keep_dims=True)

    def _compute_log_mpe_value(self, weight_tensor, ivs_tensor, *value_tensors):
        weight_tensor, ivs_tensor, values = self._compute_value_common(
            weight_tensor, ivs_tensor, *value_tensors)
        values_selected = values + ivs_tensor if self._ivs else values
        values_weighted = values_selected + weight_tensor
        return tf.reduce_max(values_weighted, 1, keep_dims=True)

    def _compute_mpe_path_common(self, values_weighted, counts, weight_value,
                                 ivs_value, *value_values):
        # Propagate the counts to the max value
        max_indices = tf.argmax(values_weighted, dimension=1)
        max_counts = tf.one_hot(max_indices,
                                values_weighted.get_shape()[1]) * counts
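        # E.g. (illustrative): for one sample with values_weighted of width 4,
        # argmax column 2 and counts [[3.]], max_counts becomes
        # [[0., 0., 3., 0.]], i.e. the whole count mass is routed to the
        # winning child.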
        # Split the counts to value inputs
        _, _, *value_sizes = self.get_input_sizes(None, None, *value_values)
        max_counts_split = utils.split_maybe(max_counts, value_sizes, 1)
        return self._scatter_to_input_tensors(
            (max_counts, weight_value),  # Weights
            (max_counts, ivs_value),  # IVs
            *[(t, v) for t, v in zip(max_counts_split, value_values)])  # Values

    def _compute_mpe_path(self, counts, weight_value, ivs_value, *value_values,
                          add_random=None, use_unweighted=False):
        # Get weighted, IV selected values
        weight_value, ivs_value, values = self._compute_value_common(
            weight_value, ivs_value, *value_values)
        values_selected = values * ivs_value if self._ivs else values
        values_weighted = values_selected * weight_value
        return self._compute_mpe_path_common(
            values_weighted, counts, weight_value, ivs_value, *value_values)

    def _compute_log_mpe_path(self, counts, weight_value, ivs_value, *value_values,
                              add_random=None, use_unweighted=False):
        # Get weighted, IV selected values
        weight_value, ivs_value, values = self._compute_value_common(
            weight_value, ivs_value, *value_values)
        values_selected = values + ivs_value if self._ivs else values

        # WARN USING UNWEIGHTED VALUE
        if not use_unweighted or any(v.node.is_var for v in self._values):
            values_weighted = values_selected + weight_value
        else:
            values_weighted = values_selected

        # / USING UNWEIGHTED VALUE

        # WARN ADDING RANDOM NUMBERS
        if add_random is not None:
            values_weighted = tf.add(values_weighted, tf.random_uniform(
                shape=(tf.shape(values_weighted)[0],
                       int(values_weighted.get_shape()[1])),
                minval=0, maxval=add_random,
                dtype=conf.dtype))
        # /ADDING RANDOM NUMBERS

        return self._compute_mpe_path_common(
            values_weighted, counts, weight_value, ivs_value, *value_values)
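
# A small NumPy sketch (not libspn code) of the log-domain computations above:
# a Sum node's marginal log-value is logsumexp(log child values + log weights),
# and its MPE log-value replaces the sum with a max. Names are illustrative.
import numpy as np

def sum_log_value(log_values, log_weights):
    """Marginal log-value of a single sum over its weighted children."""
    weighted = log_values + log_weights
    m = weighted.max(axis=1, keepdims=True)
    return m + np.log(np.exp(weighted - m).sum(axis=1, keepdims=True))

def sum_log_mpe_value(log_values, log_weights):
    """MPE log-value: the best weighted child instead of the weighted sum."""
    return (log_values + log_weights).max(axis=1, keepdims=True)

child = np.log(np.array([[0.2, 0.5, 0.3]]))
w = np.log(np.array([0.1, 0.6, 0.3]))
print(np.exp(sum_log_value(child, w)))      # [[0.41]] = 0.02 + 0.30 + 0.09
print(np.exp(sum_log_mpe_value(child, w)))  # [[0.30]] = best weighted child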
Beispiel #11
0
class DenseSPNGenerator:
    """Generates a dense SPN according to the algorithm described in
    Poon&Domingos UAI'11.

    Attributes:
        num_decomps (int): Number of decompositions at each level.
        num_subsets (int): Number of variable sub-sets for each decomposition.
        num_mixtures (int): Number of mixtures (sums) for each variable subset.
        input_dist (InputDist): Determines how inputs sharing the same scope
                                (for instance IndicatorLeaf for different values of a
                                random variable) should be included into the
                                generated structure.
        num_input_mixtures (int): Number of mixtures used for combining all
                                  inputs sharing scope when ``input_dist`` is
                                  set to ``MIXTURE``. If set to ``None``,
                                  ``num_mixtures`` is used.
        balanced (bool): Use only balanced decompositions, into subsets of
                         similar cardinality (differing by max 1).
        node_type (NodeType): Determines the type of op-node - single (Sum, Product),
                              block (ParallelSums, PermuteProducts) or layer (SumsLayer,
                              ProductsLayer) - to be used in the generated structure.

    """

    __logger = get_logger()
    __debug1 = __logger.debug1
    __debug2 = __logger.debug2
    __debug3 = __logger.debug3

    class InputDist(Enum):
        """Determines how inputs sharing the same scope (for instance IndicatorLeaf for
        different values of a random variable) should be included into the
        generated structure."""

        RAW = 0
        """Each input is considered a different distribution over the scope and
        used directly instead of a mixture as an input to product nodes for
        singleton variable subsets."""

        MIXTURE = 1
        """``input_num_mixtures`` mixtures are created over all the inputs
        sharing a scope, effectively creating ``input_num_mixtures``
        distributions over the scope. These mixtures are then used as inputs
        to product nodes for singleton variable subsets."""

    class NodeType(Enum):
        """Determines the type of op-node - single (Sum, Product), block (ParallelSums,
        PermuteProducts) or layer (SumsLayer, ProductsLayer) - to be used in the
        generated structure."""

        SINGLE = 0
        BLOCK = 1
        LAYER = 2

    class SubsetInfo:
        """Stores information about a single subset to be decomposed.

        Attributes:
            level(int): Number of the SPN layer where the subset is decomposed.
            subset(list of tuple of tuple): Subset of inputs to decompose
                                            grouped by scope.
            parents(list of Sum): List of sum nodes mixing the outputs of the
                                  generated decompositions. At the very top
                                  level this is just the root node.
        """

        def __init__(self, level, subset, parents):
            self.level = level
            self.subset = subset
            self.parents = parents

    def __init__(self, num_decomps, num_subsets, num_mixtures,
                 input_dist=InputDist.MIXTURE, num_input_mixtures=None,
                 balanced=True, node_type=NodeType.LAYER):
        # Args
        if not isinstance(num_decomps, int) or num_decomps < 1:
            raise ValueError("num_decomps must be a positive integer")
        if not isinstance(num_subsets, int) or num_subsets < 1:
            raise ValueError("num_subsets must be a positive integer")
        if not isinstance(num_mixtures, int) or num_mixtures < 1:
            raise ValueError("num_mixtures must be a positive integer")
        if input_dist not in DenseSPNGenerator.InputDist:
            raise ValueError("Incorrect input_dist: %s", input_dist)
        if (num_input_mixtures is not None and
                (not isinstance(num_input_mixtures, int)
                 or num_input_mixtures < 1)):
            raise ValueError("num_input_mixtures must be None"
                             " or a positive integer")

        # Attributes
        self.num_decomps = num_decomps
        self.num_subsets = num_subsets
        self.num_mixtures = num_mixtures
        self.input_dist = input_dist
        self.balanced = balanced
        self.node_type = node_type
        if num_input_mixtures is None:
            self.num_input_mixtures = num_mixtures
        else:
            self.num_input_mixtures = num_input_mixtures

        # Stirling numbers and ratios for partition sampling
        self.__stirling = utils.Stirling()

    def generate(self, *inputs, rnd=None, root_name=None):
        """Generate the SPN.

        Args:
            inputs (input_like): Inputs to the generated SPN.
            rnd (Random): Optional. A custom instance of a random number generator
                          ``random.Random`` that will be used instead of the
                          default global instance. This permits using a generator
                          with a custom state independent of the global one.
            root_name (str): Name of the root node of the generated SPN.

        Returns:
           Sum: Root node of the generated SPN.
        """
        self.__debug1(
            "Generating dense SPN (num_decomps=%s, num_subsets=%s,"
            " num_mixtures=%s, input_dist=%s, num_input_mixtures=%s)",
            self.num_decomps, self.num_subsets,
            self.num_mixtures, self.input_dist, self.num_input_mixtures)
        inputs = [Input.as_input(i) for i in inputs]
        input_set = self.__generate_set(inputs)
        self.__debug1("Found %s distinct input scopes",
                      len(input_set))

        # Create root
        root = Sum(name=root_name)

        # Subsets left to process
        subsets = deque()
        subsets.append(DenseSPNGenerator.SubsetInfo(level=1,
                                                    subset=input_set,
                                                    parents=[root]))

        # Process subsets layer by layer
        self.__decomp_id = 1  # Id number of a decomposition, for info only
        while subsets:
            # Process whole layer (all subsets at the same level)
            level = subsets[0].level
            self.__debug1("Processing level %s", level)
            while subsets and subsets[0].level == level:
                subset = subsets.popleft()
                new_subsets = self.__add_decompositions(subset, rnd)
                for s in new_subsets:
                    subsets.append(s)

        # If NodeType is LAYER, convert the generated graph with LayerNodes
        return (self.convert_to_layer_nodes(root) if self.node_type ==
                DenseSPNGenerator.NodeType.LAYER else root)

    def __generate_set(self, inputs):
        """Generate a set of inputs to the generated SPN grouped by scope.

        Args:
            inputs (list of Input): List of inputs.

        Returns:
           list of tuple of tuple: A list where each element is a tuple of
               all inputs to the generated SPN which share the same scope.
               Each of these scopes is guaranteed to be unique. The tuple
               contains tuples ``(node, index)`` which uniquely identify
               specific inputs.
        """
        scope_dict = {}  # Dict indexed by scope

        def add_input(scope, node, index):
            try:
                # Try appending to existing scope
                scope_dict[scope].add((node, index))
            except KeyError:
                # Scope not in dict, check if it overlaps with other scopes
                for s in scope_dict:
                    if s & scope:
                        raise StructureError("Differing scopes of inputs overlap")
                # Add to dict
                scope_dict[scope] = set([(node, index)])

        # Process inputs and group by scope
        for inpt in inputs:
            node_scopes = inpt.node.get_scope()
            if inpt.indices is None:
                for index, scope in enumerate(node_scopes):
                    add_input(scope, inpt.node, index)
            else:
                for index in inpt.indices:
                    add_input(node_scopes[index], inpt.node, index)

        # Convert to hashable tuples and sort
        # Sorting might improve performance due to branch prediction
        return [tuple(sorted(i)) for i in scope_dict.values()]

    def __add_decompositions(self, subset_info: SubsetInfo, rnd: random.Random):
        """Add nodes for a single subset, i.e. an instance of ``num_decomps``
        decompositions of ``subset`` into ``num_subsets`` sub-subsets with
        ``num_mixtures`` mixtures per sub-subset.

        Args:
            subset_info(SubsetInfo): Info about the subset being decomposed.
            rnd (Random): A custom instance of a random number generator or
                          ``None`` if default global instance should be used.

        Returns:
            list of SubsetInfo: Info about each new generated subset, which
            requires further decomposition.
        """

        def subsubset_to_inputs_list(subsubset):
            """Convert sub-subsets into a list of tuples, where each
               tuple contains an input and a list of indices
            """
            subsubset_list = list(next(iter(subsubset)))
            # Create a list of unique inputs from sub-subsets list
            unique_inputs = list(set(s_subset[0]
                                 for s_subset in subsubset_list))
            # For each unique input, collect all associated indices
            # into a single list, then create a list of tuples,
            # where each tuple contains a unique input and its
            # list of indices
            inputs_list = []
            for unique_inp in unique_inputs:
                indices_list = []
                for s_subset in subsubset_list:
                    if s_subset[0] == unique_inp:
                        indices_list.append(s_subset[1])
                inputs_list.append(tuple((unique_inp, indices_list)))

            return inputs_list

        # Get subset partitions
        self.__debug3("Decomposing subset:\n%s", subset_info.subset)
        num_elems = len(subset_info.subset)
        num_subsubsets = min(num_elems, self.num_subsets)  # Requested num subsets
        partitions = utils.random_partitions(subset_info.subset, num_subsubsets,
                                             self.num_decomps,
                                             balanced=self.balanced,
                                             rnd=rnd,
                                             stirling=self.__stirling)
        self.__debug2("Randomized %s decompositions of a subset"
                      " of %s elements into %s sets",
                      len(partitions), num_elems, num_subsubsets)

        # Generate nodes for each decomposition/partition
        subsubset_infos = []
        for part in partitions:
            self.__debug2("Decomposition %s: into %s subsubsets of cardinality %s",
                          self.__decomp_id, len(part), [len(s) for s in part])
            self.__debug3("Decomposition %s subsubsets:\n%s",
                          self.__decomp_id, part)
            # Handle each subsubset
            sums_id = 1
            prod_inputs = []
            for subsubset in part:
                if self.node_type == DenseSPNGenerator.NodeType.SINGLE:
                    # Use single-nodes
                    if len(subsubset) > 1:  # Decomposable further
                        # Add mixtures
                        with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                            sums = [Sum(name="Sum%s" % (i + 1))
                                    for i in range(self.num_mixtures)]
                            sums_id += 1
                        # Register the mixtures as inputs of products
                        prod_inputs.append([(s, 0) for s in sums])
                        # Generate subsubset info
                        subsubset_infos.append(DenseSPNGenerator.SubsetInfo(
                            level=subset_info.level + 1, subset=subsubset,
                            parents=sums))
                    else:  # Non-decomposable
                        if self.input_dist == DenseSPNGenerator.InputDist.RAW:
                            # Register the content of subset as inputs to products
                            prod_inputs.append(next(iter(subsubset)))
                        elif self.input_dist == DenseSPNGenerator.InputDist.MIXTURE:
                            # Add mixtures
                            with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                                sums = [Sum(name="Sum%s" % (i + 1))
                                        for i in range(self.num_input_mixtures)]
                                sums_id += 1
                            # Register the mixtures as inputs of products
                            prod_inputs.append([(s, 0) for s in sums])
                            # Create an inputs list
                            inputs_list = subsubset_to_inputs_list(subsubset)
                            # Connect inputs to mixtures
                            for s in sums:
                                s.add_values(*inputs_list)
                else:  # Use multi-nodes
                    if len(subsubset) > 1:  # Decomposable further
                        # Add mixtures
                        with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                            sums = ParallelSums(num_sums=self.num_mixtures,
                                                name="ParallelSums%s.%s" %
                                                     (self.__decomp_id, sums_id))
                            sums_id += 1
                        # Register the mixtures as inputs of PermProds
                        prod_inputs.append(sums)
                        # Generate subsubset info
                        subsubset_infos.append(DenseSPNGenerator.SubsetInfo(
                                               level=subset_info.level + 1,
                                               subset=subsubset, parents=[sums]))
                    else:  # Non-decomposable
                        if self.input_dist == DenseSPNGenerator.InputDist.RAW:
                            # Create an inputs list
                            inputs_list = subsubset_to_inputs_list(subsubset)
                            if len(inputs_list) > 1:
                                inputs_list = [Concat(*inputs_list)]
                            # Register the content of subset as inputs to PermProds
                            prod_inputs.extend(inputs_list)
                        elif self.input_dist == DenseSPNGenerator.InputDist.MIXTURE:
                            # Create an inputs list
                            inputs_list = subsubset_to_inputs_list(subsubset)
                            # Add mixtures
                            with tf.name_scope("Sums%s.%s" % (self.__decomp_id, sums_id)):
                                sums = ParallelSums(*inputs_list,
                                                    num_sums=self.num_input_mixtures,
                                                    name="ParallelSums%s.%s" %
                                                         (self.__decomp_id, sums_id))
                                sums_id += 1
                            # Register the mixtures as inputs of PermProds
                            prod_inputs.append(sums)
            # Add product nodes
            if self.node_type == DenseSPNGenerator.NodeType.SINGLE:
                products = self.__add_products(prod_inputs)
            else:
                products = ([PermuteProducts(*prod_inputs, name="PermuteProducts%s" % self.__decomp_id)]
                            if len(prod_inputs) > 1 else prod_inputs)
            # Connect products to each parent Sum
            for p in subset_info.parents:
                p.add_values(*products)
            # Increment decomposition id
            self.__decomp_id += 1
        return subsubset_infos

    def __add_products(self, prod_inputs):
        """
        Add product nodes for a single decomposition and connect them to their
        input nodes.

        Args:
            prod_inputs (list of list of Node): A list of lists of nodes
                being inputs to the products, grouped by scope.

        Returns:
            list of Product: A list of product nodes.
        """
        selected = [0 for _ in prod_inputs]  # Input selected for each scope
        cont = True
        products = []
        product_num = 1
        with tf.name_scope("Products%s" % self.__decomp_id):
            while cont:
                if len(prod_inputs) > 1:
                    # Add a product node
                    products.append(Product(*[pi[s] for (pi, s) in
                                              zip(prod_inputs, selected)],
                                            name="Product%s" % product_num))
                else:
                    products.append(*[pi[s] for (pi, s) in
                                      zip(prod_inputs, selected)])
                product_num += 1
                # Increment selected
                cont = False
                for i, group in enumerate(prod_inputs):
                    if selected[i] < len(group) - 1:
                        selected[i] += 1
                        for j in range(i):
                            selected[j] = 0
                        cont = True
                        break
        return products

    def convert_to_layer_nodes(self, root):
        """
        At each level in the SPN rooted in the 'root' node, model all the nodes
        as a single layer-node.

        Args:
            root (Node): The root of the SPN graph.

        Returns:
            root (Node): The root of the SPN graph, with each layer modelled as a
                         single layer-node.
        """

        parents = defaultdict(list)
        depths = defaultdict(list)
        node_to_depth = OrderedDict()
        node_to_depth[root] = 1

        def get_parents(node):
            # Add to Parents dict
            if node.is_op:
                for i in node.inputs:
                    if (i and  # Input not empty
                            not(i.is_param or i.is_var or
                                isinstance(i.node, (SumsLayer, ProductsLayer, ConvSums,
                                                    ConvProducts, Concat, LocalSums)))):
                        parents[i.node].append(node)
                        node_to_depth[i.node] = node_to_depth[node] + 1

        def permute_inputs(input_values, input_sizes):
            # For a given list of inputs and their corresponding sizes, create a
            # nested-list of (input, index) pairs.
            # E.g: input_values = [(A, [2, 5]), (B, None)]
            #      input_sizes = [2, 3]
            #      inputs = [[('A', 2), ('A', 5)],
            #                [('B', 0), ('B', 1), ('B', 2)]]
            inputs = [list(product([inp.node], inp.indices)) if inp and inp.indices
                      else list(product([inp.node], list(range(inp_size)))) for
                      inp, inp_size in zip(input_values, input_sizes)]

            # For a given nested-list of (input, index) pairs, permute over the inputs
            # E.g: permuted_inputs = [('A', 2), ('B', 0),
            #                         ('A', 2), ('B', 1),
            #                         ('A', 2), ('B', 2),
            #                         ('A', 5), ('B', 0),
            #                         ('A', 5), ('B', 1),
            #                         ('A', 5), ('B', 2)]
            permuted_inputs = list(product(*[inps for inps in inputs]))
            return list(chain(*permuted_inputs))

        # Create a parents dictionary of the SPN graph
        traverse_graph(root, fun=get_parents, skip_params=True)

        # Create a depth dictionary of the SPN graph
        for key, value in node_to_depth.items():
            depths[value].append(key)
        spn_depth = len(depths)

        # Iterate through each depth of the SPN, starting from the deepest layer,
        # moving up to the root node
        for depth in range(spn_depth, 1, -1):
            if isinstance(depths[depth][0], (Sum, ParallelSums)):  # A Sums Layer
                # Create a default SumsLayer node
                with tf.name_scope("Layer%s" % depth):
                    sums_layer = SumsLayer(name="SumsLayer-%s.%s" % (depth, 1))
                # Initialize a counter for keeping track of number of sums
                # modelled in the layer node
                layer_num_sums = 0
                # Initialize an empty list for storing sum-input-sizes of sums
                # modelled in the layer node
                num_or_size_sums = []
                # Iterate through each node at the current depth of the SPN
                for node in depths[depth]:
                    # TODO: To be replaced with node.num_sums once AbstractSums
                    # class is introduced
                    # No. of sums modelled by the current node
                    node_num_sums = (1 if isinstance(node, Sum) else node.num_sums)
                    # Add Input values of the current node to the SumsLayer node
                    sums_layer.add_values(*node.values * node_num_sums)
                    # Add sum-input-size, of each sum modelled in the current node,
                    # to the list
                    num_or_size_sums += [sum(node.get_input_sizes()[2:])] * node_num_sums
                    # Visit each parent of the current node
                    for parent in parents[node]:
                        try:
                            # 'Values' in case parent is an Op node
                            values = list(parent.values)
                        except AttributeError:
                            # 'Inputs' in case parent is a Concat node
                            values = list(parent.inputs)
                        # Iterate through each input value of the current parent node
                        for i, value in enumerate(values):
                            # If the value is the current node
                            if value.node == node:
                                # Check if it has indices
                                if value.indices is not None:
                                    # If so, then just add the num-sums of the
                                    # layer-op as offset
                                    indices = (np.asarray(value.indices) +
                                               layer_num_sums).tolist()
                                else:
                                    # If not, then create a list accordingly
                                    indices = list(range(layer_num_sums,
                                                         (layer_num_sums +
                                                          node_num_sums)))
                                # Replace previous (node) Input value in the
                                # current parent node, with the new layer-node value
                                values[i] = (sums_layer, indices)
                                break  # Once child-node found, don't have to search further
                        # Reset values of the current parent node, by including
                        # the new child (Layer-node)
                        try:
                            # set 'values' in case parent is an Op node
                            parent.set_values(*values)
                        except AttributeError:
                            # set 'inputs' in case parent is a Concat node
                            parent.set_inputs(*values)
                    # Increment num-sums-counter of the layer-node
                    layer_num_sums += node_num_sums
                # After all nodes at a certain depth are modelled into a Layer-node,
                # set num-sums parameter accordingly
                sums_layer.set_sum_sizes(num_or_size_sums)
            elif isinstance(depths[depth][0], (Product, PermuteProducts)):  # A Products Layer
                with tf.name_scope("Layer%s" % depth):
                    prods_layer = ProductsLayer(name="ProductsLayer-%s.%s" % (depth, 1))
                # Initialize a counter for keeping track of number of prods
                # modelled in the layer node
                layer_num_prods = 0
                # Initialize an empty list for storing prod-input-sizes of prods
                # modelled in the layer node
                num_or_size_prods = []
                # Iterate through each node at the current depth of the SPN
                for node in depths[depth]:
                    # Get input values and sizes of the product node
                    input_values = list(node.values)
                    input_sizes = list(node.get_input_sizes())
                    if isinstance(node, PermuteProducts):
                        # Permute over input-values to model permuted products
                        input_values = permute_inputs(input_values, input_sizes)
                        node_num_prods = node.num_prods
                        prod_input_size = len(input_values) // node_num_prods
                    elif isinstance(node, Product):
                        node_num_prods = 1
                        prod_input_size = int(sum(input_sizes))

                    # Add Input values of the current node to the ProductsLayer node
                    prods_layer.add_values(*input_values)
                    # Add prod-input-size, of each product modelled in the current
                    # node, to the list
                    num_or_size_prods += [prod_input_size] * node_num_prods
                    # Visit each parent of the current node
                    for parent in parents[node]:
                        values = list(parent.values)
                        # Iterate through each input value of the current parent node
                        for i, value in enumerate(values):
                            # If the value is the current node
                            if value.node == node:
                                # Check if it has indices
                                if value.indices is not None:
                                    # If so, then just add the num-prods of the
                                    # layer-op as offset
                                    indices = value.indices + layer_num_prods
                                else:
                                    # If not, then create a list accordingly
                                    indices = list(range(layer_num_prods,
                                                         (layer_num_prods +
                                                          node_num_prods)))
                                # Replace previous (node) Input value in the
                                # current parent node, with the new layer-node value
                                values[i] = (prods_layer, indices)
                        # Reset values of the current parent node, by including
                        # the new child (Layer-node)
                        parent.set_values(*values)
                    # Increment num-prods-counter of the layer node
                    layer_num_prods += node_num_prods
                # After all nodes at a certain depth are modelled into a Layer-node,
                # set num-prods parameter accordingly
                prods_layer.set_prod_sizes(num_or_size_prods)
            elif isinstance(depths[depth][0], (SumsLayer, ProductsLayer, Concat)):  # Already a layer or Concat node
                pass
            else:
                raise StructureError("Unknown node-type: {}".format(depths[depth][0]))

        return root
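
# Hypothetical usage sketch for the generator above. It assumes a libspn-style
# module namespace ("spn") exposing the classes from this listing; the exact
# names and signatures are assumptions, not verified API.
import libspn as spn

iv_x = spn.IVs(num_vars=4, num_vals=2, name="iv_x")  # four binary variables
gen = spn.DenseSPNGenerator(num_decomps=1, num_subsets=2, num_mixtures=2,
                            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE,
                            node_type=spn.DenseSPNGenerator.NodeType.LAYER)
root = gen.generate(iv_x, root_name="Root")
root.generate_weights()  # weights for the root Sum only; deeper sums need their own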
Beispiel #12
0
"""LibSVM session and running helpers."""

from contextlib import contextmanager
import tensorflow as tf
from libspn.log import get_logger

logger = get_logger()


@contextmanager
def session():
    """Context manager initializing and deinitializing the standard TensorFlow
    session infrastructure.

    Use like this::

        with spn.session() as (sess, run):
            while run():
                sess.run(something)
    """
    # Op for initializing variables
    # As stated in https://github.com/tensorflow/tensorflow/issues/3819,
    # the global initializer does not include local variables (e.g. the epoch
    # counter), so local variables are initialized explicitly as well
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Create a session
    with tf.Session() as sess:
        # Initialize epoch counter and other variables
        sess.run(init_op)
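
        # The original helper is truncated in this listing. A minimal,
        # hypothetical completion matching the documented usage (yielding the
        # session together with a run() predicate) could look like this:
        keep_running = True

        def run():
            # Stand-in predicate; the real helper presumably tracks epochs or
            # input queues rather than running a single iteration
            nonlocal keep_running
            result, keep_running = keep_running, False
            return result

        yield sess, run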
Beispiel #13
0
class _ConvProdNaive(ProductsLayer):
    """A container representing convolutional products in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this container.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_channels (int): Number of output channels modelled by this
            convolutional product container.
        padding (str): Padding algorithm used by the convolution,
            e.g. ``'valid'``.
        strides (int or tuple of int): Strides of the convolution.
        kernel_size (int or tuple of int): Size of the convolution kernel.
        name (str): Name of the container.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this container that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """

    logger = get_logger()

    def __init__(self,
                 *values,
                 num_channels=None,
                 padding='valid',
                 dilation_rate=1,
                 strides=2,
                 kernel_size=2,
                 name="ConvProd2D",
                 sparse_connections=None,
                 dense_connections=None,
                 spatial_dim_sizes=None,
                 num_channels_max=512,
                 pad_top=None,
                 pad_bottom=None,
                 pad_left=None,
                 pad_right=None):
        self._batch_axis = 0
        self._channel_axis = 3

        self._pad_top, self._pad_bottom = pad_top, pad_bottom
        self._pad_left, self._pad_right = pad_left, pad_right

        super().__init__(inference_type=InferenceType.MARGINAL, name=name)
        self.set_values(*values)

        num_channels = min(num_channels or num_channels_max, num_channels_max)
        if spatial_dim_sizes is None:
            raise NotImplementedError(
                "{}: Must also provide spatial_dim_sizes at this point.".
                format(self))

        self._spatial_dim_sizes = spatial_dim_sizes or [-1] * 2
        if isinstance(self._spatial_dim_sizes, tuple):
            self._spatial_dim_sizes = list(self._spatial_dim_sizes)
        self._padding = padding
        self._dilation_rate = [dilation_rate] * 2 \
            if isinstance(dilation_rate, int) else list(dilation_rate)
        self._strides = [strides] * 2 \
            if isinstance(strides, int) else list(strides)
        self._num_channels = num_channels
        self._kernel_size = [kernel_size] * 2 if isinstance(kernel_size, int) \
            else list(kernel_size)

        if sparse_connections is not None:
            if dense_connections is not None:
                raise ValueError(
                    "{}: Must provide either spare connections or dense connections, "
                    "not both.".format(self))
            self._sparse_connections = sparse_connections
            self._dense_connections = self.sparse_connections_to_dense(
                sparse_connections)
        elif dense_connections is not None:
            self._dense_connections = dense_connections
            self._sparse_connections = self.dense_connections_to_sparse(
                dense_connections)
        else:
            self._sparse_connections = self.generate_sparse_connections(
                num_channels)
            self._dense_connections = self.sparse_connections_to_dense(
                self._sparse_connections)

        self._set_prod_sizes()

    @staticmethod
    def _pad_or_zero(pad):
        return 0 if pad is None else pad

    def _explicit_pad_sizes(self):
        # Replace any 'None' with 0
        pad_top = self._pad_or_zero(self._pad_top)
        pad_bottom = self._pad_or_zero(self._pad_bottom)
        pad_left = self._pad_or_zero(self._pad_left)
        pad_right = self._pad_or_zero(self._pad_right)
        return pad_top, pad_bottom, pad_left, pad_right

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        if len(values) > 1:
            raise NotImplementedError("Currently only supports single input")
        self._values = self._parse_inputs(*values)

    def _set_prod_sizes(self):
        indices, prod_sizes = self._forward_gather_indices_from_kernel()
        self.set_values((self.values[0].node, indices))
        self.set_prod_sizes(prod_sizes)

    def _forward_gather_indices_from_kernel(self):
        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        out_rows, out_cols = self.output_shape_spatial[:2]

        kstart_row = -pad_top

        num_inp_channels = self._num_input_channels()

        def sub2ind(row, col, channel):
            return int(row * self._spatial_dim_sizes[1] * num_inp_channels + \
                       col * num_inp_channels + \
                       channel)

        indices = []
        prod_sizes = []
        for out_r in range(out_rows):
            kstart_col = -pad_left
            for out_c in range(out_cols):
                for out_channel in range(self._num_channels):
                    ind_current_prod = []
                    for kernel_row in range(self._kernel_size[0]):
                        input_row = kstart_row + kernel_row * self._dilation_rate[
                            0]
                        for kernel_col in range(self._kernel_size[1]):
                            input_col = kstart_col + kernel_col * self._dilation_rate[
                                1]
                            if 0 <= input_row < self._spatial_dim_sizes[0] and \
                                    0 <= input_col < self._spatial_dim_sizes[1]:
                                channel = self._sparse_connections \
                                    [kernel_row, kernel_col, out_channel]
                                ind_current_prod.append(
                                    sub2ind(input_row, input_col, channel))

                    # End of output channel
                    indices.extend(ind_current_prod)
                    prod_sizes.append(len(ind_current_prod))

                # End of output column
                kstart_col += self._strides[1]
            # End of output row
            kstart_row += self._strides[0]

        return indices, prod_sizes
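    # Illustration (not part of the original class): ``sub2ind`` above flattens a
    # (row, col, channel) coordinate of the input grid into an index of the
    # flattened input, i.e. row * num_cols * num_in_channels + col * num_in_channels
    # + channel. For example, with a 4x4 grid and 2 input channels, the coordinate
    # (row=1, col=2, channel=1) maps to 1 * 4 * 2 + 2 * 2 + 1 = 13.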

    def _preconv_grid_sizes(self):
        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        return pad_top + pad_bottom + self._spatial_dim_sizes[0], \
               pad_left + pad_right + self._spatial_dim_sizes[1]

    def sparse_connections_to_dense(self, sparse):
        # Sparse has shape [rows, cols, out_channels]
        # Dense will have shape [rows, cols, in_channels, out_channels]
        num_input_channels = self._num_input_channels()
        batch_grid_prod = int(np.prod(self._kernel_size +
                                      [self._num_channels]))
        dense = np.zeros([batch_grid_prod, num_input_channels])

        dense[np.arange(batch_grid_prod), sparse.ravel()] = 1.0
        dense = dense.reshape(self._kernel_size +
                              [self._num_channels, num_input_channels])
        # [rows, cols, out_channels, in_channels] to [rows, cols, in_channels, out_channels]
        return np.transpose(dense,
                            (0, 1, 3, 2)).astype(conf.dtype.as_numpy_dtype())

    def dense_connections_to_sparse(self, dense):
        return np.argmax(dense, axis=2)

    def generate_sparse_connections(self, num_channels):
        num_input_channels = self._num_input_channels()
        kernel_surface = int(np.prod(self._kernel_size))
        total_possibilities = num_input_channels**kernel_surface
        if num_channels >= total_possibilities:
            if num_channels > total_possibilities:
                self.logger.warn(
                    "Number of channels exceeds total number of combinations.")
                self._num_channels = total_possibilities
            p = np.arange(total_possibilities)
            kernel_cells = []
            for _ in range(kernel_surface):
                kernel_cells.append(p % num_input_channels)
                p //= num_input_channels
            return np.stack(kernel_cells,
                            axis=0).reshape(self._kernel_size +
                                            [total_possibilities])

        sparse_shape = self._kernel_size + [num_channels]
        size = int(np.prod(sparse_shape))
        return np.random.randint(num_input_channels,
                                 size=size).reshape(sparse_shape)

    @utils.lru_cache
    def _spatial_concat(self, *input_tensors):
        input_tensors = [self._spatial_reshape(t) for t in input_tensors]
        reducible_inputs = tf.concat(input_tensors, axis=self._channel_axis)
        return reducible_inputs

    def _spatial_reshape(self, t, forward=True):
        """Reshapes a Tensor ``t``` to one that represents the spatial dimensions.

        Args:
            t (Tensor): The ``Tensor`` to reshape.
            forward (bool): Whether to reshape for forward inference. If True, reshapes to
                ``[batch, rows, cols, 1, input_channels]``. Otherwise, reshapes to
                ``[batch, rows, cols, output_channels, input_channels]``.
        Returns:
             A reshaped ``Tensor``.
        """
        non_batch_dim_size = self._non_batch_dim_prod(t)
        if forward:
            input_channels = non_batch_dim_size // np.prod(
                self._spatial_dim_sizes)
            return tf.reshape(t, [-1] + self._spatial_dim_sizes +
                              [input_channels])
        return tf.reshape(t, [-1] + self._spatial_dim_sizes +
                          [self._num_channels])

    def _non_batch_dim_prod(self, t):
        """Computes the product of the non-batch dimensions to be used for reshaping purposes.

        Args:
            t (Tensor): A ``Tensor`` for which to compute the product.

        Returns:
            An ``int``: product of non-batch dimensions.
        """
        non_batch_dim_size = np.prod([
            ds for i, ds in enumerate(t.shape.as_list())
            if i != self._batch_axis
        ])
        return int(non_batch_dim_size)

    def _num_channels_per_input(self):
        """Returns a list of number of input channels for each value Input.

        Returns:
            A list of ints containing the number of channels.
        """
        input_sizes = self.get_input_sizes()
        return [
            int(s // np.prod(self._spatial_dim_sizes)) for s in input_sizes
        ]

    def _num_input_channels(self):
        return sum(self._num_channels_per_input())

    def _compute_out_size_spatial(self, *input_out_sizes):
        kernel_size0, kernel_size1 = self._effective_kernel_size()

        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()

        rows_post_pad = pad_top + pad_bottom + self._spatial_dim_sizes[0]
        cols_post_pad = pad_left + pad_right + self._spatial_dim_sizes[1]
        rows_post_pad -= kernel_size0 - 1
        cols_post_pad -= kernel_size1 - 1
        out_rows = int(np.ceil(rows_post_pad / self._strides[0]))
        out_cols = int(np.ceil(cols_post_pad / self._strides[1]))
        return int(out_rows), int(out_cols), self._num_channels

    def _compute_out_size(self, *input_out_sizes):
        return int(np.prod(self._compute_out_size_spatial(*input_out_sizes)))

    @property
    def output_shape_spatial(self):
        return self._compute_out_size_spatial()

    @property
    def same_padding(self):
        return self._padding.lower() == "same"

    @property
    def valid_padding(self):
        return self._padding.lower() == "valid"

    def _effective_kernel_size(self):
        # See https://www.tensorflow.org/api_docs/python/tf/nn/convolution
        return [(self._kernel_size[0] - 1) * self._dilation_rate[0] + 1,
                (self._kernel_size[1] - 1) * self._dilation_rate[1] + 1]

    def pad_sizes(self):
        pad_top_explicit, pad_bottom_explicit, pad_left_explicit, pad_right_explicit = \
            self._explicit_pad_sizes()
        if self.valid_padding:
            # No padding
            if pad_top_explicit == pad_bottom_explicit \
                    == pad_left_explicit == pad_right_explicit == 0:
                return 0, 0, 0, 0
            return pad_left_explicit, pad_right_explicit, pad_top_explicit, pad_bottom_explicit

        # See https://www.tensorflow.org/api_guides/python/nn#Convolution
        filter_height, filter_width = self._effective_kernel_size()
        if self._spatial_dim_sizes[0] % self._strides[0] == 0:
            pad_along_height = max(filter_height - self._strides[0], 0)
        else:
            pad_along_height = max(
                filter_height -
                (self._spatial_dim_sizes[0] % self._strides[0]), 0)
        if self._spatial_dim_sizes[1] % self._strides[1] == 0:
            pad_along_width = max(filter_width - self._strides[1], 0)
        else:
            pad_along_width = max(
                filter_width - (self._spatial_dim_sizes[1] % self._strides[1]),
                0)

        pad_top = pad_along_height // 2
        pad_bottom = pad_along_height - pad_top
        pad_left = pad_along_width // 2
        pad_right = pad_along_width - pad_left
        return (pad_left + pad_left_explicit, pad_right + pad_right_explicit,
                pad_top + pad_top_explicit, pad_bottom + pad_bottom_explicit)

    def deserialize_inputs(self, data, nodes_by_name):
        pass

    @property
    def inputs(self):
        return self._values

    def serialize(self):
        pass

    def deserialize(self, data):
        pass

    @property
    def _const_out_size(self):
        return True
Beispiel #14
0
class ConvProducts(OpNode):
    """A container representing convolutional products in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this container.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_channels (int): Number of channels modeled by this node. This parameter is optional.
            If ``None``, the layer will attempt to generate all possible permutations of channels
            under a patch as long as it is under ``num_channels_max``.
        padding (str): Type of padding used. Can be either 'full', 'valid' or 'wicker_top'.
            For building Wicker CSPNs, 'full' padding is necessary in all but the very last
            ConvProducts node. The last ConvProducts node should use the 'wicker_top' padding
            algorithm.
        dilation_rate (int or tuple of ints): Dilation rate of the convolution.
        strides (int or tuple of ints): Strides used for the convolution.
        sparse_connections (numpy.ndarray): Sparse representation of connections
            [height, width, num_out_channels]
        dense_connections (numpy.ndarray): Dense representation of connections
            ([height, width, num_in_channels, num_out_channels])
        spatial_dim_sizes (tuple or list of ints): Dim sizes of spatial dimensions (height and width)
        num_channels_max (int): The maximum number of channels when automatically generating
            permutations.
        name (str): Name of the container.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this container that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """

    logger = get_logger()

    def __init__(self,
                 *values,
                 num_channels=None,
                 padding='valid',
                 dilation_rate=1,
                 strides=2,
                 kernel_size=2,
                 inference_type=InferenceType.MARGINAL,
                 name="ConvProducts",
                 sparse_connections=None,
                 dense_connections=None,
                 spatial_dim_sizes=None,
                 num_channels_max=512):
        self._batch_axis = 0
        self._channel_axis = 3
        super().__init__(inference_type=inference_type, name=name)
        self.set_values(*values)

        num_channels = min(num_channels or num_channels_max, num_channels_max)

        self._spatial_dim_sizes = self._values[0].node.output_shape_spatial[:2] if \
            isinstance(self._values[0].node, (SpatialSums, ConvProducts)) else spatial_dim_sizes
        if self._spatial_dim_sizes is None:
            raise StructureError(
                "{}: spatial_dim_sizes must be provided if they cannot be inferred "
                "from the input node.".format(self))
        if isinstance(self._spatial_dim_sizes, tuple):
            self._spatial_dim_sizes = list(self._spatial_dim_sizes)
        self._padding = padding
        self._dilation_rate = [dilation_rate] * 2 \
            if isinstance(dilation_rate, int) else list(dilation_rate)
        self._strides = [strides] * 2 \
            if isinstance(strides, int) else list(strides)
        self._num_channels = num_channels
        self._kernel_size = [kernel_size] * 2 if isinstance(kernel_size, int) \
            else list(kernel_size)

        # Generate connections if needed
        if sparse_connections is not None:
            if dense_connections is not None:
                raise ValueError(
                    "{}: Must provide either spare connections or dense connections, "
                    "not both.".format(self))
            self._sparse_connections = sparse_connections
            self._dense_connections = self.sparse_kernels_to_onehot(
                sparse_connections)
        elif dense_connections is not None:
            self._dense_connections = dense_connections
            self._sparse_connections = self.onehot_kernels_to_sparse(
                dense_connections)
        else:
            self._sparse_connections = self.generate_sparse_kernels(
                num_channels)
            self._dense_connections = self.sparse_kernels_to_onehot(
                self._sparse_connections)

        self._scope_mask = None

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        if len(values) == 0:
            raise StructureError(
                "{}: must be initialized with at least one input.".format(
                    self))
        self._values = self._parse_inputs(*values)

    def sparse_kernels_to_onehot(self, sparse):
        """Converts an index-based representation of sparse kernels to a dense onehot
        representation.

        Args:
            sparse (numpy.ndarray): A sparse kernel representation of shape
                [rows, cols, output_channel] containing the indices for which the kernel equals 1.

        Returns:
            A onehot representation of the same kernel with shape
            [rows, cols, input_channel, output_channel].
        """
        # Sparse has shape [rows, cols, out_channels]
        # Dense will have shape [rows, cols, in_channels, out_channels]
        num_input_channels = self._num_input_channels()
        batch_grid_prod = int(np.prod(self._kernel_size +
                                      [self._num_channels]))
        dense = np.zeros([batch_grid_prod, num_input_channels])

        dense[np.arange(batch_grid_prod), sparse.ravel()] = 1.0
        dense = dense.reshape(self._kernel_size +
                              [self._num_channels, num_input_channels])
        # [rows, cols, out_channels, in_channels] to [rows, cols, in_channels, out_channels]
        return np.transpose(dense,
                            (0, 1, 3, 2)).astype(conf.dtype.as_numpy_dtype())

    def onehot_kernels_to_sparse(self, dense):
        """Converts a dense kernel to an index-based representation.

        Args:
            dense (numpy.ndarray): The dense one-hot representation of the kernels with shape
                [rows, cols, input_channel, output_channel]

        Returns:
            An index-based representation of the kernels of shape [rows, cols, output_channel].
        """
        return np.argmax(dense, axis=2)
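    # Illustration (not part of the original class): the two representations above
    # are inverses of each other. A small numpy round trip, assuming a 2x2 kernel,
    # 3 input channels and 4 output channels:
    #
    #     sparse = np.random.randint(3, size=[2, 2, 4])
    #     rows, cols, chans = np.indices(sparse.shape)
    #     onehot = np.zeros([2, 2, 3, 4])
    #     onehot[rows, cols, sparse, chans] = 1.0
    #     assert np.array_equal(np.argmax(onehot, axis=2), sparse)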

    def generate_sparse_kernels(self, num_channels):
        """Generates sparse kernels kernels. These kernels only contain '1' on a single channel
        per row and column. The remaining values are all zero. This method returns the sparse
        representation, containing only the indices for which the kernels are 1 along the input
        channel axis.

        Args:
            num_channels (int): The number of channels. In case the number of channels given is
                larger than the number of possible one-hot assignments, a warning is given and the
                number of channels is set accordingly before generating the connections.
        Returns:
            A `numpy.ndarray` containing the 'sparse' representation of the kernels with shape
            `[row, column, channel]`, containing the indices of the input channel for which the
            kernel is 1.
        """
        num_input_channels = self._num_input_channels()
        kernel_surface = int(np.prod(self._kernel_size))
        total_possibilities = num_input_channels**kernel_surface
        if num_channels >= total_possibilities:
            if num_channels > total_possibilities:
                self.logger.warn(
                    "Number of channels exceeds total number of combinations.")
                self._num_channels = total_possibilities
            p = np.arange(total_possibilities)
            kernel_cells = []
            for _ in range(kernel_surface):
                kernel_cells.append(p % num_input_channels)
                p //= num_input_channels
            return np.stack(kernel_cells,
                            axis=0).reshape(self._kernel_size +
                                            [total_possibilities])

        if self._num_channels >= num_input_channels:
            kernel_cells = []
            for _ in range(kernel_surface):
                ind = np.arange(self._num_channels) % num_input_channels
                np.random.shuffle(ind)
                kernel_cells.append(ind)
            return np.asarray(kernel_cells).reshape(self._kernel_size +
                                                    [self._num_channels])

        sparse_shape = self._kernel_size + [num_channels]
        size = int(np.prod(sparse_shape))
        return np.random.randint(num_input_channels,
                                 size=size).reshape(sparse_shape)
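    # Illustration (not part of the original class): when all combinations are
    # enumerated above, column p of the result holds the base-num_input_channels
    # digits of p spread over the kernel cells. E.g. with 2 input channels and a
    # 2x2 kernel there are 2**4 == 16 possible kernels, and kernel p == 6 selects
    # input channels (0, 1, 1, 0) for its four cells (6 = 0*1 + 1*2 + 1*4 + 0*8).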

    @utils.lru_cache
    def _spatial_concat(self, *input_tensors):
        """Concatenates input tensors spatially. Makes sure to reshape them before.

        Args:
            input_tensors (tuple): A tuple of `Tensor`s to concatenate along the channel axis.

        Returns:
            The concatenated tensor.
        """
        input_tensors = [self._spatial_reshape(t) for t in input_tensors]
        return tf.concat(input_tensors, axis=self._channel_axis)

    def _spatial_reshape(self, t, forward=True):
        """Reshapes a Tensor ``t``` to one that represents the spatial dimensions.

        Args:
            t (Tensor): The ``Tensor`` to reshape.
            forward (bool): Whether to reshape for forward inference. If True, reshapes to
                ``[batch, rows, cols, 1, input_channels]``. Otherwise, reshapes to
                ``[batch, rows, cols, output_channels, input_channels]``.
        Returns:
             A reshaped ``Tensor``.
        """
        non_batch_dim_size = self._non_batch_dim_prod(t)
        if forward:
            input_channels = non_batch_dim_size // np.prod(
                self._spatial_dim_sizes)
            return tf.reshape(t, [-1] + self._spatial_dim_sizes +
                              [input_channels])
        return tf.reshape(t, [-1] + self._spatial_dim_sizes +
                          [self._num_channels])

    def _non_batch_dim_prod(self, t):
        """Computes the product of the non-batch dimensions to be used for reshaping purposes.

        Args:
            t (Tensor): A ``Tensor`` for which to compute the product.

        Returns:
            An ``int``: product of non-batch dimensions.
        """
        non_batch_dim_size = np.prod([
            ds for i, ds in enumerate(t.shape.as_list())
            if i != self._batch_axis
        ])
        return int(non_batch_dim_size)

    def _set_scope_mask(self, t):
        self._scope_mask = t

    def _num_channels_per_input(self):
        """Returns a list of number of input channels for each value Input.

        Returns:
            A list of ints containing the number of channels.
        """
        input_sizes = self.get_input_sizes()
        return [
            int(s // np.prod(self._spatial_dim_sizes)) for s in input_sizes
        ]

    def _num_input_channels(self):
        """Computes the total number of input channels.

        Returns:
            An int indicating the number of input channels for the convolution operation.
        """
        return sum(self._num_channels_per_input())

    def _compute_out_size_spatial(self, *input_out_sizes):
        """Computes spatial output shape.

        Returns:
            A tuple with (num_rows, num_cols, num_channels).
        """
        kernel_size0, kernel_size1 = self._effective_kernel_size()

        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()

        rows_post_pad = pad_top + pad_bottom + self._spatial_dim_sizes[
            0] - kernel_size0 + 1
        cols_post_pad = pad_left + pad_right + self._spatial_dim_sizes[
            1] - kernel_size1 + 1
        out_rows = int(np.ceil(rows_post_pad / self._strides[0]))
        out_cols = int(np.ceil(cols_post_pad / self._strides[1]))
        return int(out_rows), int(out_cols), self._num_channels

    def _compute_out_size(self, *input_out_sizes):
        return int(np.prod(self._compute_out_size_spatial(*input_out_sizes)))

    @property
    def output_shape_spatial(self):
        """tuple: The spatial shape of this node, formatted as (rows, columns, channels). """
        return self._compute_out_size_spatial()

    @property
    def same_padding(self):
        """bool: Whether the padding algorithm is set to SAME. """
        return self._padding.lower() == "same"

    @property
    def valid_padding(self):
        """bool: Whether the padding algrorithm is set to VALID """
        return self._padding.lower() == "valid"

    def _effective_kernel_size(self):
        """Computes the 'effective' kernel size by also taking into account the dilation rate.

        Returns:
            tuple: A tuple with (num_kernel_rows, num_kernel_cols)
        """
        return [(self._kernel_size[0] - 1) * self._dilation_rate[0] + 1,
                (self._kernel_size[1] - 1) * self._dilation_rate[1] + 1]
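    # Illustration (not part of the original class): with kernel_size=2 and
    # dilation_rate=2, the effective kernel size is (2 - 1) * 2 + 1 == 3 in each
    # spatial dimension, i.e. the kernel covers a 3x3 patch with one-cell gaps.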

    @utils.lru_cache
    def _compute_log_value(self, *input_tensors):
        # Concatenate along channel axis
        concat_inp = self._prepare_convolutional_processing(*input_tensors)

        # Convolve
        # TODO: this is the quickest workaround for TensorFlow's apparent optimization whenever
        # part of the kernel computation involves a -inf:
        concat_inp = tf.where(tf.is_inf(concat_inp),
                              tf.fill(tf.shape(concat_inp), value=-1e20),
                              concat_inp)

        conv_out = tf.nn.conv2d(input=concat_inp,
                                filter=self._dense_connections,
                                padding='VALID',
                                strides=[1] + self._strides + [1],
                                dilations=[1] + self._dilation_rate + [1],
                                data_format='NHWC')
        return self._flatten(conv_out)

    @utils.lru_cache
    def _transpose_channel_last_to_first(self, t):
        return tf.transpose(t, (0, 3, 1, 2))

    @utils.lru_cache
    def _transpose_channel_first_to_last(self, t):
        return tf.transpose(t, (0, 2, 3, 1))

    def _compute_log_mpe_value(self, *input_tensors):
        return self._compute_log_value(*input_tensors)

    def _compute_mpe_path_common(self, counts, *input_values):
        if not self._values:
            raise StructureError("{} is missing input values.".format(self))
        # Concatenate inputs along channel axis, should already be done during forward pass
        inp_concat = self._prepare_convolutional_processing(*input_values)
        spatial_counts = tf.reshape(counts, (-1, ) + self.output_shape_spatial)

        input_counts = tf.nn.conv2d_backprop_input(
            input_sizes=tf.shape(inp_concat),
            filter=self._dense_connections,
            out_backprop=spatial_counts,
            strides=[1] + self._strides + [1],
            padding='VALID',
            dilations=[1] + self._dilation_rate + [1],
            data_format="NHWC")

        # In case we have explicitly padded the tensor before forward convolution, we should
        # slice the counts now
        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        if not any([pad_bottom, pad_left, pad_right, pad_top]):
            return self._split_to_children(input_counts)
        return self._split_to_children(input_counts[:, pad_top:-pad_bottom,
                                                    pad_left:-pad_right, :])

    @utils.lru_cache
    def _prepare_convolutional_processing(self, *input_values):
        inp_concat = self._spatial_concat(*input_values)
        return self._maybe_pad(inp_concat)

    def _maybe_pad(self, x):
        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        if not any([pad_left, pad_right, pad_top, pad_bottom]):
            # If all pad sizes are 0, just return x
            return x
        # Pad x
        paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right],
                    [0, 0]]
        return tf.pad(x, paddings=paddings, mode="CONSTANT", constant_values=0)

    def _compute_log_mpe_path(self,
                              counts,
                              *input_values,
                              use_unweighted=False,
                              sample=False,
                              sample_prob=None):
        return self._compute_mpe_path_common(counts, *input_values)

    @utils.lru_cache
    def _split_to_children(self, x):
        if len(self.inputs) == 1:
            return [self._flatten(x)]
        x_split = tf.split(x,
                           num_or_size_splits=self._num_channels_per_input(),
                           axis=self._channel_axis)
        return [self._flatten(t) for t in x_split]

    @utils.lru_cache
    def _flatten(self, t):
        """Flattens a Tensor ``t`` so that the resulting shape is [batch, non_batch]

        Args:
            t (Tensor): A ``Tensor`` to flatten.

        Returns:
            A flattened ``Tensor``.
        """
        if self._batch_axis != 0:
            raise NotImplementedError(
                "{}: Cannot flatten if batch axis isn't equal to zero.".format(
                    self))
        non_batch_dim_size = self._non_batch_dim_prod(t)
        return tf.reshape(t, (-1, non_batch_dim_size))

    @utils.docinherit(OpNode)
    def _compute_scope(self, *value_scopes, check_valid=False):
        flat_value_scopes = self._gather_input_scopes(*value_scopes)

        value_scopes_grid = [
            np.asarray(vs).reshape(self._spatial_dim_sizes + [-1])
            for vs in flat_value_scopes
        ]
        value_scopes_concat = np.concatenate(value_scopes_grid, axis=2)

        dilate = self._dilation_rate
        kernel_size = self._kernel_size
        grid_dims = self._spatial_dim_sizes
        strides = self._strides
        input_channels = self._num_input_channels()

        pad_left, pad_right, pad_top, pad_bottom = self.pad_sizes()
        if any(p != 0 for p in [pad_right, pad_left, pad_top, pad_bottom]):
            padded_value_scopes_concat = np.empty(
                (pad_top + grid_dims[0] + pad_bottom,
                 pad_left + grid_dims[1] + pad_right, input_channels),
                dtype=Scope)
            # Pad with empty scopes
            empty_scope = Scope.merge_scopes([])
            padded_value_scopes_concat[:, :pad_left] = empty_scope
            padded_value_scopes_concat[:pad_top, :] = empty_scope
            padded_value_scopes_concat[-pad_bottom:, :] = empty_scope
            padded_value_scopes_concat[:, -pad_right:] = empty_scope
            padded_value_scopes_concat[
                pad_top:pad_top + value_scopes_concat.shape[0],
                pad_left:pad_left +
                value_scopes_concat.shape[1]] = value_scopes_concat
            value_scopes_concat = padded_value_scopes_concat

        scope_list = []
        kernel_size0, kernel_size1 = self._effective_kernel_size()
        # Reset grid dims as we might have padded the scopes
        grid_dims = value_scopes_concat.shape[:2]
        for row in range(0, grid_dims[0] - kernel_size0 + 1, strides[0]):
            row_indices = list(range(row, row + kernel_size0, dilate[0]))
            for col in range(0, grid_dims[1] - kernel_size1 + 1, strides[1]):
                col_indices = list(range(col, col + kernel_size1, dilate[1]))
                for channel in range(self._num_channels):
                    single_scope = []
                    for im_row, kernel_row in zip(row_indices,
                                                  range(kernel_size[0])):
                        for im_col, kernel_col in zip(col_indices,
                                                      range(kernel_size[1])):
                            single_scope.append(value_scopes_concat[
                                im_row, im_col,
                                self._sparse_connections[kernel_row,
                                                         kernel_col, channel]])
                    # Ensure valid
                    if check_valid:
                        for sc1, sc2 in itertools.combinations(
                                single_scope, 2):
                            if sc1 & sc2:
                                # Invalid if intersection not empty
                                self.logger.warn(
                                    "{} is not decomposable".format(self))
                                return None
                    scope_list.append(Scope.merge_scopes(single_scope))

        return scope_list

    def pad_sizes(self):
        """Determines the pad sizes. Possibly adds up explicit padding and padding through SAME
        padding algorithm of `tf.nn.convolution`.

        Returns:
            A tuple of left, right, top and bottom padding sizes.
        """
        if self._padding == 'valid':
            return 0, 0, 0, 0
        if self._padding == 'full':
            kernel_height, kernel_width = self._effective_kernel_size()
            pad_top = pad_bottom = kernel_height - 1
            pad_left = pad_right = kernel_width - 1
            return pad_left, pad_right, pad_top, pad_bottom
        if self._padding == 'wicker_top':
            kernel_height, kernel_width = self._effective_kernel_size()
            pad_bottom = (kernel_height - 1) * 2 - self._spatial_dim_sizes[0]
            pad_right = (kernel_width - 1) * 2 - self._spatial_dim_sizes[1]
            return 0, pad_right, 0, pad_bottom
        raise ValueError(
            "{}: invalid padding algorithm. Use 'valid', 'full' or 'wicker_top', got '{}'"
            .format(self, self._padding))
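    # Illustration (not part of the original class): with 'full' padding,
    # kernel_size=2 and dilation_rate=1, the effective kernel size is 2, so the
    # grid is padded by 2 - 1 == 1 cell on every side; with dilation_rate=2 the
    # effective size is 3 and the padding grows to 2 cells on every side.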

    @utils.docinherit(OpNode)
    def _compute_valid(self, *value_scopes):
        return self._compute_scope(*value_scopes, check_valid=True)

    def deserialize_inputs(self, data, nodes_by_name):
        self._values = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['values'])

    @property
    def inputs(self):
        return self._values

    def serialize(self):
        data = super().serialize()
        data['padding'] = self._padding
        data['spatial_dim_sizes'] = self._spatial_dim_sizes
        data['sparse_connections'] = self._sparse_connections
        return data

    def deserialize(self, data):
        super().deserialize(data)
        self._padding = data['padding']
        self._spatial_dim_sizes = data['spatial_dim_sizes']
        self._sparse_connections = data['sparse_connections']
        self._dense_connections = self.sparse_kernels_to_onehot(
            self._sparse_connections)

    @property
    def _const_out_size(self):
        return True
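# Illustration (not part of the original example): with one-hot kernels, the
# convolution in _compute_log_value sums exactly one log-input per kernel cell,
# which equals the log of the product of the selected inputs. A minimal numpy
# check for a single 2x2 patch with 3 input channels (all names below are
# assumptions made for this sketch):
import numpy as np

log_patch = np.log(np.random.rand(2, 2, 3))      # log-values, [rows, cols, in_channels]
sparse = np.array([[0, 2], [1, 1]])              # chosen input channel per kernel cell
onehot = np.zeros((2, 2, 3))
rows, cols = np.indices(sparse.shape)
onehot[rows, cols, sparse] = 1.0                 # one-hot kernel for a single output channel
conv_out = np.sum(log_patch * onehot)            # what the convolution computes for this patch
direct = np.sum(log_patch[rows, cols, sparse])   # sum of the selected log-values
assert np.isclose(conv_out, direct)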
Beispiel #15
0
class BaseSum(OpNode, abc.ABC):
    """An abstract node representing sums in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_sums (int): Number of Sum ops modelled by this node.
        sum_sizes (list): A list of ints corresponding to the sizes of each sum. If both num_sums
                          and sum_sizes are given, we should have len(sum_sizes) == num_sums.
        batch_axis (int): The index of the batch axis.
        op_axis (int): The index of the op axis that contains the individual sums being modeled.
        reduce_axis (int): The axis over which to perform summing (or max for MPE)
        weights (input_like): Input providing weights node to this sum node.
            See :meth:`~libspn.Input.as_input` for possible values. If set
            to ``None``, the input is disconnected.
        ivs (input_like): Input providing IVs of an explicit latent variable
            associated with this sum node. See :meth:`~libspn.Input.as_input`
            for possible values. If set to ``None``, the input is disconnected.
        name (str): Name of the node.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this node that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """
    def __init__(self,
                 *values,
                 num_sums,
                 weights=None,
                 ivs=None,
                 sum_sizes=None,
                 inference_type=InferenceType.MARGINAL,
                 batch_axis=0,
                 op_axis=1,
                 reduce_axis=2,
                 masked=False,
                 sample_prob=None,
                 dropconnect_keep_prob=None,
                 name="Sum"):
        super().__init__(inference_type=inference_type, name=name)

        self.set_values(*values)
        self.set_weights(weights)
        self.set_ivs(ivs)

        # Initialize the number of sums and the sum sizes
        self._reset_sum_sizes(num_sums=num_sums, sum_sizes=sum_sizes)

        # Set the axes
        self._batch_axis = batch_axis
        self._op_axis = op_axis
        self._reduce_axis = reduce_axis

        # Set whether this instance is masked (e.g. SumsLayer)
        self._masked = masked

        # Set the sampling probability and sampling type
        self._sample_prob = sample_prob

        # Set dropconnect keep probability
        self._dropconnect_keep_prob = dropconnect_keep_prob

    def _get_sum_sizes(self, num_sums):
        """Computes a list of sum sizes given the number of sums and the currently attached input
        nodes.

        Args:
            num_sums (int): The number of sums modeled by this node.
        Returns:
            A list of sum sizes, where the i-th element corresponds to the size of the i-th sum.
        """
        input_sizes = self.get_input_sizes()
        num_values = sum(input_sizes[2:])  # Skip ivs, weights
        return num_sums * [num_values]
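    # Illustration (not part of the original class): with two value inputs of
    # sizes 4 and 6 (input_sizes == [weights_size, ivs_size, 4, 6]) and
    # num_sums=3, this returns [10, 10, 10].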

    def _build_mask(self):
        """Constructs mask that could be used to cancel out 'columns' that are padded as a result of
        varying reduction sizes. Returns a Boolean mask.

        Returns:
            By default the sums are not masked, returns ``None``
        """
        return None

    @utils.docinherit(OpNode)
    def serialize(self):
        data = super().serialize()
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        if self._weights:
            data['weights'] = (self._weights.node.name, self._weights.indices)
        if self._ivs:
            data['ivs'] = (self._ivs.node.name, self._ivs.indices)
        data['num_sums'] = self._num_sums
        data['sum_sizes'] = self._sum_sizes
        data['op_axis'] = self._op_axis
        data['reduce_axis'] = self._reduce_axis
        data['batch_axis'] = self._batch_axis
        data['dropconnect_keep_prob'] = self._dropconnect_keep_prob
        data['sample_prob'] = self._sample_prob
        return data

    @utils.docinherit(OpNode)
    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()
        self.set_weights()
        self.set_ivs()
        self._dropconnect_keep_prob = data['dropconnect_keep_prob']
        self._sample_prob = data['sample_prob']
        self._num_sums = data['num_sums']
        self._sum_sizes = data['sum_sizes']
        self._max_sum_size = max(self._sum_sizes) if self._sum_sizes else 0
        self._batch_axis = data['batch_axis']
        self._op_axis = data['op_axis']
        self._reduce_axis = data['reduce_axis']

    def disconnect_inputs(self):
        self._ivs = self._weights = self._values = None

    @utils.docinherit(OpNode)
    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['values'])
        weights = data.get('weights', None)
        if weights:
            self._weights = Input(nodes_by_name[weights[0]], weights[1])
        ivs = data.get('ivs', None)
        if ivs:
            self._ivs = Input(nodes_by_name[ivs[0]], ivs[1])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return (self._weights, self._ivs) + self._values

    @property
    def dropconnect_keep_prob(self):
        return self._dropconnect_keep_prob

    def set_dropconnect_keep_prob(self, p):
        self._dropconnect_keep_prob = p

    @property
    def weights(self):
        """Input: Weights input."""
        return self._weights

    def set_weights(self, weights=None):
        """Set the weights input.

        Args:
            weights (input_like): Input providing weights node to this sum node.
                See :meth:`~libspn.Input.as_input` for possible values. If set
                to ``None``, the input is disconnected.
        """
        weights, = self._parse_inputs(weights)
        if weights and not isinstance(weights.node, Weights):
            raise StructureError("%s is not Weights" % weights.node)
        self._weights = weights

    def _reset_sum_sizes(self, num_sums=None, sum_sizes=None):
        """Resets the sizes and number of sums. If number of sums is specified, it will take that
        value, otherwise it will take the value that is already set. If sum_sizes is specified
        it will take that value, otherwise it will infer that using
        :meth:`~libspn.BaseSum._get_sum_sizes`. Finally, it also sets the maximum sum size.

        Args:
            num_sums (int): Number of sums modeled by this ``Node``.
            sum_sizes (int): A list of sum sizes with as many ``int``s as there are sums modeled.
        """
        self._num_sums = num_sums or self._num_sums
        self._sum_sizes = sum_sizes or self._get_sum_sizes(self._num_sums)
        self._max_sum_size = max(self._sum_sizes) if self._sum_sizes else 0

    @property
    def ivs(self):
        """Input: IVs input."""
        return self._ivs

    def set_ivs(self, ivs=None):
        """Set the IVs input.

        Args:
            ivs (input_like): Input providing IVs of an explicit latent variable
                associated with this sum node. See :meth:`~libspn.Input.as_input`
                for possible values. If set to ``None``, the input is disconnected.
        """
        self._ivs, = self._parse_inputs(ivs)

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    @property
    def num_sums(self):
        """int: The number of sums modeled by this node. """
        return self._num_sums

    @property
    def sum_sizes(self):
        """list of int: A list of the sum sizes. """
        return self._sum_sizes

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def add_values(self, *values):
        """Add more inputs providing input values to this node.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values += self._parse_inputs(*values)
        self._reset_sum_sizes()

    def generate_weights(self,
                         initializer=tf.initializers.constant(1.0),
                         trainable=True,
                         input_sizes=None,
                         log=False,
                         name=None):
        """Generate a weights node matching this sum node and connect it to
        this sum.

        The function calculates the number of weights based on the number
        of input values of this sum. Therefore, weights should be generated
        once all inputs are added to this node.

        Args:
            initializer: Initial value of the weights.
            trainable (bool): See :class:`~libspn.Weights`.
            input_sizes (list of int): Pre-computed sizes of each input of
                this node.  If given, this function will not traverse the graph
                to discover the sizes.
            log (bool): If "True", the weights are represented in log space.
            name (str): Name of the weighs node. If ``None`` use the name of the
                        sum + ``_Weights``.

        Return:
            Weights: Generated weights node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_Weights"
        # Count all input values
        if input_sizes:
            num_values = sum(input_sizes[2:])  # Skip ivs, weights
        else:
            num_values = max(self._sum_sizes)
        # Generate weights
        weights = Weights(initializer=initializer,
                          num_weights=num_values,
                          num_sums=self._num_sums,
                          log=log,
                          trainable=trainable,
                          name=name)
        self.set_weights(weights)
        return weights

    def generate_ivs(self, feed=None, name=None):
        """Generate an IVs node matching this sum node and connect it to
        this sum.

        IVs should be generated once all inputs are added to this node,
        otherwise the number of IVs will be incorrect.

        Args:
            feed (Tensor): See :class:`~libspn.IVs`.
            name (str): Name of the IVs node. If ``None`` use the name of the
                        sum + ``_IVs``.

        Return:
            IVs: Generated IVs node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_IVs"
        ivs = IVs(feed=feed,
                  num_vars=self._num_sums,
                  num_vals=self._max_sum_size,
                  name=name)
        self.set_ivs(ivs)
        return ivs

    @utils.lru_cache
    def _compute_reducible(self,
                           w_tensor,
                           ivs_tensor,
                           *input_tensors,
                           weighted=True,
                           dropconnect_keep_prob=None):
        """Computes a reducible ``Tensor`` so that reducing it over the last axis can be used for
        marginal inference, MPE inference and MPE path computation.

        Args:
            w_tensor (Tensor): A ``Tensor`` with the value of the weights of shape
                ``[num_sums, max_sum_size]``
            ivs_tensor (Tensor): A ``Tensor`` with the value of the IVs corresponding to this node
                of shape ``[batch, num_sums * max_sum_size]``.
            input_tensors (tuple): A ``tuple`` of ``Tensors``s with the values of the children of
                this node.
            weighted (bool): Whether to apply the weights to the reducible values if possible.
            dropconnect_keep_prob (Tensor or float): A scalar ``Tensor`` or float that holds the
                dropconnect keep probability. By default it is None, in which case no dropconnect
                is being used.

        Returns:
            A ``Tensor`` of shape ``[batch, num_sums, max_sum_size]`` that can be used for computing
            marginal inference, MPE inference, gradients or MPE paths.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if not self._weights:
            raise StructureError("%s is missing weights" % self)

        # Prepare tensors for component-wise application of weights and IVs
        w_tensor, ivs_tensor, reducible = self._prepare_component_wise_processing(
            w_tensor, ivs_tensor, *input_tensors, zero_prob_val=-float('inf'))

        # Apply latent IVs
        if self._ivs:
            reducible = utils.cwise_add(reducible, ivs_tensor)

        # Apply weights
        if weighted:
            # Maybe apply dropconnect
            dropconnect_keep_prob = utils.maybe_first(
                dropconnect_keep_prob, self._dropconnect_keep_prob)

            if dropconnect_keep_prob is not None and dropconnect_keep_prob != 1.0:
                if self._ivs:
                    self.logger.warn(
                        "Using dropconnect and latent IVs simultaneously. "
                        "This might result in zero probabilities throughout and unpredictable "
                        "behavior of learning. Therefore, dropconnect is turned off for node {}."
                        .format(self))
                else:
                    mask = self._create_dropout_mask(dropconnect_keep_prob,
                                                     tf.shape(reducible),
                                                     log=True)
                    w_tensor = utils.cwise_add(w_tensor, mask)
                    if conf.renormalize_dropconnect:
                        w_tensor = tf.nn.log_softmax(w_tensor, axis=-1)
            reducible = utils.cwise_add(reducible, w_tensor)

        return reducible
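    # Illustration (not part of the original class): the returned tensor holds
    # log(w_k) + log(v_k) per child k of every sum, so reducing it over the last
    # axis with a log-sum-exp (presumably what _reduce_marginal_inference_log
    # does in _compute_log_value) gives log(sum_k w_k * v_k) for each sum.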

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_out_size(self, *input_out_sizes):
        return self._num_sums

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_value(self,
                           w_tensor,
                           ivs_tensor,
                           *value_tensors,
                           dropconnect_keep_prob=None):

        # Defines soft-gradient for the log value
        def custom_grad(grad):
            # Use the _compute_log_gradient method to compute the gradient w.r.t. to the
            # inputs of this node.
            scattered_grads = self._compute_log_gradient(
                grad,
                w_tensor,
                ivs_tensor,
                *value_tensors,
                accumulate_weights_batch=True,
                dropconnect_keep_prob=dropconnect_keep_prob)

            return [sg for sg in scattered_grads if sg is not None]

        # Wrap the log value with its custom gradient
        @tf.custom_gradient
        def _log_value(*input_tensors):
            # First reduce over last axis
            val = self._reduce_marginal_inference_log(
                self._compute_reducible(
                    w_tensor,
                    ivs_tensor,
                    *value_tensors,
                    weighted=True,
                    dropconnect_keep_prob=dropconnect_keep_prob))

            return val, custom_grad

        # if conf.custom_gradient:
        #     return _log_value(*self._get_differentiable_inputs(
        #         w_tensor, ivs_tensor, *value_tensors))
        # else:
        return self._reduce_marginal_inference_log(
            self._compute_reducible(w_tensor,
                                    ivs_tensor,
                                    *value_tensors,
                                    weighted=True,
                                    dropconnect_keep_prob=dropconnect_keep_prob))

    def _get_differentiable_inputs(self, w_tensor, ivs_tensor, *value_tensors):
        """Selects the tensors to include for a tf.custom_gradient when computing the log-value.

        Args:
            w_tensor (Tensor): A ``Tensor`` of shape [num_sums, max_sum_size] with the value of
                               the weights corresponding to this node.
            ivs_tensor (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] with the
                                 value of the IVs corresponding to this node.
            value_tensors (tuple): A ``tuple`` of ``Tensor``s with the values of the children
                                   of this node.
        """
        return [w_tensor
                ] + ([ivs_tensor] if self._ivs else []) + list(value_tensors)

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_mpe_value(self,
                               w_tensor,
                               ivs_tensor,
                               *value_tensors,
                               dropconnect_keep_prob=None):

        # Defines hard-gradient for the log-mpe
        def custom_grad(grad):
            scattered_grads = self._compute_log_mpe_path(
                grad, w_tensor, ivs_tensor, *value_tensors)

            return [sg for sg in scattered_grads if sg is not None]

        # Wrap the log value with its custom gradient
        @tf.custom_gradient
        def _log_mpe_value(*input_tensors):
            val = self._reduce_mpe_inference_log(
                self._compute_reducible(
                    w_tensor,
                    ivs_tensor,
                    *value_tensors,
                    weighted=True,
                    dropconnect_keep_prob=dropconnect_keep_prob))
            return val, custom_grad

        if conf.custom_gradient:
            return _log_mpe_value(*self._get_differentiable_inputs(
                w_tensor, ivs_tensor, *value_tensors))
        else:
            return self._reduce_mpe_inference_log(
                self._compute_reducible(
                    w_tensor,
                    ivs_tensor,
                    *value_tensors,
                    weighted=True,
                    dropconnect_keep_prob=dropconnect_keep_prob))

    @utils.lru_cache
    def _compute_mpe_path_common(self,
                                 reducible_tensor,
                                 counts,
                                 w_tensor,
                                 ivs_tensor,
                                 *input_tensors,
                                 sample=False,
                                 sample_prob=None,
                                 accumulate_weights_batch=False):
        """Common operations for computing the MPE path.

        Args:
            reducible_tensor (Tensor): A (weighted) ``Tensor`` of (log-)values of this node.
            counts (Tensor): A ``Tensor`` that contains the accumulated counts of the parents
                             of this node.
            w_tensor (Tensor):  A ``Tensor`` containing the (log-)value of the weights.
            ivs_tensor (Tensor): A ``Tensor`` containing the (log-)value of the IVs.
            input_tensors (list): A list of ``Tensor``s with outputs of the child nodes.
            sample (bool): Whether to sample the 'winner' of the max or not
            sample_prob (Tensor): A scalar ``Tensor`` indicating the probability of drawing
                a sample. If a sample is drawn, the probability for each index is given by the
                (log-)normalized probability as given by ``reducible_tensor``.
            accumulate_weights_batch (bool): Whether to sum the weight counts over the batch
                axis before scattering them to the weights input.
        Returns:
            A ``list`` of ``tuple``s [(MPE counts, input tensor), ...] where the first corresponds
            to the Weights of this node, the second corresponds to the IVs and the remaining
            tuples correspond to the nodes in ``self._values``.
        """
        sample_prob = utils.maybe_first(sample_prob, self._sample_prob)
        if sample:
            max_indices = self._reduce_sample_log(reducible_tensor,
                                                  sample_prob=sample_prob)
        else:
            max_indices = self._reduce_argmax(reducible_tensor)
        max_counts = utils.scatter_values(params=counts,
                                          indices=max_indices,
                                          num_out_cols=self._max_sum_size)
        max_counts_acc, max_counts_split = self._accumulate_and_split_to_children(
            max_counts, *input_tensors)
        if accumulate_weights_batch:
            max_counts = tf.reduce_sum(max_counts, axis=0, keepdims=False)
        return self._scatter_to_input_tensors(
            (max_counts, w_tensor),  # Weights
            (max_counts_acc, ivs_tensor),  # IVs
            *[(t, v)
              for t, v in zip(max_counts_split, input_tensors)])  # Values

    @utils.lru_cache
    def _accumulate_and_split_to_children(self, x, *input_tensors):
        """Accumulates the values in x over the op axis. Potentially also accumulates for every
        unique input if appropriate (e.g. in SumsLayer).

        Args:
            x (Tensor): A ``Tensor`` containing the values to accumulate and split among the
                        children.
            input_tensors (tuple): A ``tuple`` of ``Tensors`` holding the value of the children's
                                   outputs. These might be used in e.g. SumsLayer to determine
                                   unique inputs so that values can be accumulated before passing
                                   them downward.
        Returns:
            A ``tuple`` of size 2 with the accumulated values and a list of accumulated values
            corresponding to each input.
        """
        if self._num_sums > 1:
            x_acc = tf.reduce_sum(x, axis=self._op_axis)
        else:
            x_acc = tf.squeeze(x, axis=self._op_axis)

        _, _, *value_sizes = self.get_input_sizes()
        return x_acc, tf.split(x_acc, value_sizes, axis=self._op_axis)
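    # Shape sketch (illustrative sizes, not from the library): with num_sums=3
    # and two children of sizes [2, 4], ``x`` has shape [batch, 3, 6]; ``x_acc``
    # then has shape [batch, 6] and is split into per-child chunks of shapes
    # [batch, 2] and [batch, 4].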

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_mpe_path(self,
                              counts,
                              w_tensor,
                              ivs_tensor,
                              *input_tensors,
                              use_unweighted=False,
                              add_random=None,
                              accumulate_weights_batch=False,
                              sample=False,
                              sample_prob=None,
                              dropconnect_keep_prob=None):
        weighted = not use_unweighted or any(v.node.is_var
                                             for v in self._values)
        reducible = self._compute_reducible(
            w_tensor,
            ivs_tensor,
            *input_tensors,
            weighted=weighted,
            dropconnect_keep_prob=dropconnect_keep_prob)
        if not weighted and self._num_sums > 1 and reducible.shape[
                self._op_axis].value == 1:
            reducible = tf.tile(reducible, (1, self._num_sums, 1))
        # Add random
        if add_random is not None:
            reducible += tf.random_uniform(tf.shape(reducible),
                                           minval=0.0,
                                           maxval=add_random,
                                           dtype=conf.dtype)

        return self._compute_mpe_path_common(
            reducible,
            counts,
            w_tensor,
            ivs_tensor,
            *input_tensors,
            accumulate_weights_batch=accumulate_weights_batch,
            sample=sample,
            sample_prob=sample_prob)

    @utils.lru_cache
    def _compute_log_gradient(self,
                              gradients,
                              w_tensor,
                              ivs_tensor,
                              *value_tensors,
                              accumulate_weights_batch=False,
                              dropconnect_keep_prob=None):
        """Computes gradient for log probabilities.

        Args:
            gradients (Tensor): A ``Tensor`` of shape [batch, num_sums] that contains the
                                accumulated backpropagated gradient coming from this node's parents.
            w_tensor (Tensor): A ``Tensor`` of shape [num_sums, max_sum_size] that contains the
                               weights corresponding to this node.
            ivs_tensor (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] that
                                 corresponds to the IVs of this node.
            value_tensors (tuple): A ``tuple`` of ``Tensor``s that correspond to the values of the
                                   children of this node.
            accumulate_weights_batch (bool): A ``bool`` that marks whether the weight gradients should be
                                     summed over the batch axis.
        Returns:
            A ``tuple`` of gradients. Starts with weights, then IVs, and the remainder corresponds
            to ``value_tensors``.
        """

        reducible = self._compute_reducible(
            w_tensor,
            ivs_tensor,
            *value_tensors,
            weighted=True,
            dropconnect_keep_prob=dropconnect_keep_prob)

        # Below exploits the memoization since _reduce_marginal_inference_log will
        # always use keepdims=False, thus yielding the same tensor. One might otherwise
        # be tempted to use keepdims=True and omit expand_dims here...
        log_sum = tf.expand_dims(
            self._reduce_marginal_inference_log(reducible),
            axis=self._reduce_axis)

        # A number - (-inf) is undefined. In fact, the gradient in those cases should be zero
        log_sum = tf.where(tf.is_inf(log_sum), tf.zeros_like(log_sum), log_sum)
        w_grad = tf.expand_dims(
            gradients, axis=self._reduce_axis) * tf.exp(reducible - log_sum)

        value_grad_acc, value_grad_split = self._accumulate_and_split_to_children(
            w_grad)

        if accumulate_weights_batch:
            w_grad = tf.reduce_sum(w_grad, axis=0, keepdims=False)

        return self._scatter_to_input_tensors(
            (w_grad, w_tensor), (value_grad_acc, ivs_tensor),
            *[(t, v) for t, v in zip(value_grad_split, value_tensors)])
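    # Note on the identity used above: since d/dx_i logsumexp(x) = exp(x_i - logsumexp(x)),
    # ``w_grad`` is just the upstream gradient multiplied by the softmax of the weighted
    # log-values, with -inf sums clamped beforehand so that fully masked entries receive
    # a zero gradient instead of NaN.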

    def _get_flat_value_scopes(self, weight_scopes, ivs_scopes, *value_scopes):
        """Get a flat representation of the value scopes per sum.

        Args:
            weight_scopes (list): A list of ``Scope``s corresponding to the weights.
            ivs_scopes (list): A list of ``Scope``s corresponding to the IVs.
            value_scopes (tuple): A ``tuple`` of ``list``s of ``Scope``s corresponding to the
                                  scope lists of the children of this node.

        Returns:
            A tuple of the flat value scopes corresponding to this node's output, the gathered
            IVs scopes and the gathered value scopes.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        _, ivs_scopes, *value_scopes = self._gather_input_scopes(
            weight_scopes, ivs_scopes, *value_scopes)
        return list(
            chain.from_iterable(value_scopes)), ivs_scopes, value_scopes

    @utils.docinherit(OpNode)
    def _compute_scope(self, weight_scopes, ivs_scopes, *value_scopes):
        flat_value_scopes, ivs_scopes, *value_scopes = self._get_flat_value_scopes(
            weight_scopes, ivs_scopes, *value_scopes)
        if self._ivs:
            sublist_size = int(len(ivs_scopes) / self._num_sums)
            # Divide gathered ivs scopes into sublists, one per modelled Sum node.
            ivs_scopes_sublists = [
                ivs_scopes[i:i + sublist_size]
                for i in range(0, len(ivs_scopes), sublist_size)
            ]
        return [
            Scope.merge_scopes(
                flat_value_scopes +
                ivs_scopes_sublists[i] if self._ivs else flat_value_scopes)
            for i in range(self._num_sums)
        ]

    @utils.docinherit(OpNode)
    def _compute_valid(self, weight_scopes, ivs_scopes, *value_scopes):
        # If already invalid, return None
        if (any(s is None for s in value_scopes)
                or (self._ivs and ivs_scopes is None)):
            return None
        flat_value_scopes, ivs_scopes_, *value_scopes_ = self._get_flat_value_scopes(
            weight_scopes, ivs_scopes, *value_scopes)
        # IVs
        if self._ivs:
            # Verify number of IVs
            if len(ivs_scopes_) != len(flat_value_scopes) * self._num_sums:
                raise StructureError(
                    "Number of IVs (%s) and values (%s) does "
                    "not match for %s" %
                    (len(ivs_scopes_), len(flat_value_scopes) * self._num_sums,
                     self))
            # Check if scope of all IVs is just one and the same variable
            if len(Scope.merge_scopes(ivs_scopes_)) > self._num_sums:
                return None
        # Check sum for completeness wrt values
        first_scope = flat_value_scopes[0]
        if any(s != first_scope for s in flat_value_scopes[1:]):
            self.info("%s is not complete with input value scopes %s", self,
                      flat_value_scopes)
            return None

        return self._compute_scope(weight_scopes, ivs_scopes, *value_scopes)

    @property
    @utils.docinherit(OpNode)
    def _const_out_size(self):
        return True

    @utils.lru_cache
    def _prepare_component_wise_processing(self,
                                           w_tensor,
                                           ivs_tensor,
                                           *input_tensors,
                                           zero_prob_val=0.0):
        """Gathers inputs and combines them so that the resulting tensor can be reduced over the
        last axis to compute the (weighted) sums.

        Args:
            w_tensor (Tensor): A ``Tensor`` with the (log-)value of the weights of this node of
                               shape [num_sums, max_sum_size]
            ivs_tensor (Tensor): A ``Tensor`` with the (log-)value of the 'latent' ``IVs``.
            input_tensors (tuple): A tuple of ``Tensor``s  holding the value of the children of this
                                   node.
            zero_prob_val (float): The value of zero probability. This is important to know if some
                                   parts of the computation should be left out for masking.
        Returns:
            A tuple of size 3 containing: a weight ``Tensor`` that can be broadcast across sums, an
            IVs ``Tensor`` that can be applied component-wise to the sums and a ``Tensor`` that
            holds the unweighted values of the sum inputs of shape [batch, num_sums, max_sum_size].
        """
        w_tensor, ivs_tensor, *input_tensors = self._gather_input_tensors(
            w_tensor, ivs_tensor, *input_tensors)
        input_tensors = [
            tf.expand_dims(t, axis=self._op_axis) if len(t.shape) == 2 else t
            for t in input_tensors
        ]
        w_tensor = tf.expand_dims(w_tensor, axis=self._batch_axis)
        reducible_inputs = tf.concat(input_tensors, axis=self._reduce_axis)
        if ivs_tensor is not None:
            ivs_tensor = tf.reshape(ivs_tensor,
                                    shape=(-1, self._num_sums,
                                           self._max_sum_size))
        return w_tensor, ivs_tensor, reducible_inputs
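    # Shape sketch (illustrative sizes, not from the library): for num_sums=3 and two
    # children of sizes [2, 4], the gathered children become [batch, 1, 2] and
    # [batch, 1, 4] and are concatenated to [batch, 1, 6]; the weights become
    # [1, 3, 6] and the IVs reshape to [batch, 3, 6], so all three broadcast to the
    # reducible shape [batch, num_sums, max_sum_size].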

    @utils.lru_cache
    def _reduce_marginal_inference_log(self, x):
        """Reduces a tensor for marginal log inference by log(sum(exp(x), axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return tf.reduce_logsumexp(x, axis=self._reduce_axis, keepdims=False)

    @utils.lru_cache
    def _reduce_mpe_inference(self, x):
        """Reduces a tensor for MPE non-log inference by max(x, axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return tf.reduce_max(x, axis=self._reduce_axis, keepdims=False)

    @utils.lru_cache
    def _reduce_mpe_inference_log(self, x):
        """Reduces a tensor for MPE log inference by max(x, axis=reduce_axis).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return self._reduce_mpe_inference(x)

    @utils.lru_cache
    def _reduce_argmax(self, x):
        """Reduces a tensor by argmax(x, axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        if conf.argmax_zero:
            # If true, uses TensorFlow's argmax directly, yielding a bias towards the zeroth index
            return tf.argmax(x, axis=self._reduce_axis)

        # Return random index in case multiple values equal max
        x_max = tf.expand_dims(self._reduce_mpe_inference(x),
                               self._reduce_axis)
        x_eq_max = tf.to_float(tf.equal(x, x_max))
        if self._masked:
            x_eq_max *= tf.expand_dims(tf.to_float(self._build_mask()),
                                       axis=self._batch_axis)
        x_eq_max /= tf.reduce_sum(x_eq_max,
                                  axis=self._reduce_axis,
                                  keepdims=True)

        return tfd.Categorical(probs=x_eq_max,
                               name="StochasticArgMax",
                               dtype=tf.int64).sample()

    @utils.lru_cache
    def _reduce_sample_log(self, logits, sample_prob=None):
        """Samples a tensor with log likelihoods, i.e. sample(x, axis=reduce_axis)).

        Args:
            logits (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over
                the last axis.
            sample_prob (Tensor or float): A ``Tensor`` or float indicating the probability of
                taking a sample.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """

        # Categorical eventually uses non-log probabilities, so here we reuse as much as we can to
        # predetermine it
        def _sample():
            logits_sum = self._reduce_marginal_inference_log(logits)
            # Normalize in log space; exponentiate once to obtain valid probabilities
            log_prob = logits - tf.expand_dims(logits_sum,
                                               axis=self._reduce_axis)
            return tfd.Categorical(probs=tf.exp(log_prob),
                                   dtype=tf.int64).sample()

        def _select_sample_or_argmax(x):
            mask = tfd.Bernoulli(
                probs=sample_prob,
                dtype=tf.bool).sample(sample_shape=tf.shape(x))
            return tf.where(mask, x, self._reduce_argmax(logits))

        if sample_prob is not None:
            if isinstance(sample_prob, (float, int)):
                if sample_prob < 0 or sample_prob > 1:
                    raise ValueError(
                        "{}: Sample probability should be between 0 and 1. Got {} "
                        "instead.".format(self, sample_prob))
                if sample_prob != 0:
                    sample_op = _sample()
                    if sample_prob == 1.0:
                        return sample_op
                    return _select_sample_or_argmax(sample_op)

            return _select_sample_or_argmax(_sample())
        else:
            return _sample()
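
The stochastic argmax above breaks ties uniformly at random instead of always returning the
lowest maximal index. A minimal NumPy-only sketch of the same idea, independent of the
TensorFlow graph (the helper name and the example array are illustrative, not part of libspn):

import numpy as np

def random_tiebreak_argmax(x, rng=np.random):
    """Argmax over the last axis that picks uniformly among tied maxima."""
    x_max = np.max(x, axis=-1, keepdims=True)
    eq_max = (x == x_max).astype(np.float64)
    probs = eq_max / eq_max.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling of one index per row from the normalized tie mask
    cumulative = np.cumsum(probs, axis=-1)
    u = rng.uniform(size=x.shape[:-1] + (1,))
    return (u > cumulative).sum(axis=-1)

x = np.array([[0.2, 0.9, 0.9],
              [0.5, 0.1, 0.4]])
print(random_tiebreak_argmax(x))  # first row: 1 or 2 at random; second row: always 0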
Beispiel #16
0
class GDLearning:
    """Assembles TF operations performing Gradient Descent learning of an SPN.

    Args:
        value_inference_type (InferenceType): The inference type used during the
            upwards pass through the SPN. Ignored if ``mpe_path`` is given.
        learning_rate (float): Learning rate parameter used for updating SPN weights.
        learning_task_type (LearningTaskType): Learning type used while learning.
        learning_method (LearningMethodType): Learning method type, can be either generative
            (LearningMethodType.GENERATIVE) or discriminative (LearningMethodType.DISCRIMINATIVE).
        marginalizing_root (Sum, ParallelSums, SumsLayer): A sum node without IndicatorLeafs attached to
            it (or IndicatorLeafs with a fixed no-evidence feed). If it is omitted here, the node
            will be constructed internally once needed.
        name (str): The name given to this instance of GDLearning.
    """

    __logger = get_logger()

    def __init__(self,
                 root,
                 value=None,
                 value_inference_type=None,
                 learning_task_type=LearningTaskType.SUPERVISED,
                 learning_method=LearningMethodType.DISCRIMINATIVE,
                 learning_rate=1e-4,
                 marginalizing_root=None,
                 name="GDLearning",
                 global_step=None,
                 linear_w_minimum=1e-2):

        if learning_task_type == LearningTaskType.UNSUPERVISED and \
                learning_method == LearningMethodType.DISCRIMINATIVE:
            raise ValueError(
                "It is not possible to do unsupervised learning discriminatively."
            )

        self._root = root
        self._marginalizing_root = marginalizing_root

        if value is not None and isinstance(value, LogValue):
            self._log_value = value
        else:
            if value is not None:
                GDLearning.__logger.warn(
                    "{}: Value instance is ignored since the current implementation does "
                    "not support gradients with non-log inference. Using a LogValue instance "
                    "instead.".format(name))
            self._log_value = LogValue(value_inference_type)
        self._learning_rate = learning_rate
        self._learning_task_type = learning_task_type
        self._learning_method = learning_method
        self._name = name
        self._global_step = global_step
        self._linear_w_minimum = linear_w_minimum

    def loss(self, learning_method=None, reduce_fn=tf.reduce_mean):
        """Assembles main objective operations. In case of generative learning it will select
        the MLE objective, whereas in discriminative learning it selects the cross entropy.

        Args:
            learning_method (LearningMethodType): The learning method (can be either generative
                or discriminative).
            reduce_fn (function): A function that reduces the per-sample losses to a scalar.

        Returns:
            An operation to compute the main loss function.
        """
        learning_method = learning_method or self._learning_method
        if learning_method == LearningMethodType.GENERATIVE:
            return self.negative_log_likelihood(reduce_fn=reduce_fn)
        return self.cross_entropy_loss(reduce_fn=reduce_fn)

    def learn(self,
              loss=None,
              optimizer=None,
              post_gradient_ops=True,
              name="LearnGD"):
        """Assemble TF operations performing GD learning of the SPN. This includes setting up
        the loss function (with regularization), setting up the optimizer and setting up
        post gradient-update ops.

        Args:
            loss (Tensor): The operation corresponding to the loss to minimize.
            optimizer (tf.train.Optimizer): A TensorFlow optimizer to use for minimizing the loss.
            post_gradient_ops (bool): Whether to use post-gradient ops such as normalization.

        Returns:
            A tuple of grouped update Ops and a loss Op.
        """
        if self._learning_task_type == LearningTaskType.SUPERVISED and self._root.latent_indicators is None:
            raise StructureError(
                "{}: the SPN rooted at {} does not have a latent IndicatorLeaf node, so cannot "
                "setup conditional class probabilities.".format(
                    self._name, self._root))

        # If a loss function is not provided, define the loss function based
        # on learning-type and learning-method
        with tf.name_scope(name):
            with tf.name_scope("Loss"):
                if loss is None:
                    if self._learning_method == LearningMethodType.GENERATIVE:
                        loss = self.negative_log_likelihood()
                    else:
                        loss = self.cross_entropy_loss()
            # Assemble TF ops for optimizing and weights normalization
            with tf.name_scope("ParameterUpdate"):
                minimize = optimizer.minimize(loss=loss)
                if post_gradient_ops:
                    return self.post_gradient_update(minimize), loss
                else:
                    return minimize, loss

    def post_gradient_update(self, update_op):
        """Constructs post-parameter update ops such as normalization of weights and clipping of
        scale parameters of NormalLeaf nodes.

        Args:
            update_op (Tensor): A Tensor corresponding to the parameter update.

        Returns:
            An updated operation where the post-processing has been ensured by TensorFlow's control
            flow mechanisms.
        """
        with tf.name_scope("PostGradientUpdate"):

            # After applying gradients to weights, normalize weights
            with tf.control_dependencies([update_op]):
                weight_norm_ops = []

                def fun(node):
                    if node.is_param:
                        weight_norm_ops.append(
                            node.normalize(
                                linear_w_minimum=self._linear_w_minimum))

                    if isinstance(node,
                                  LocationScaleLeaf) and node._trainable_scale:
                        weight_norm_ops.append(
                            tf.assign(
                                node.scale_variable,
                                tf.maximum(node.scale_variable,
                                           node._min_scale)))

                with tf.name_scope("WeightNormalization"):
                    traverse_graph(self._root, fun=fun)
            return tf.group(*weight_norm_ops, name="weight_norm")

    def cross_entropy_loss(self,
                           name="CrossEntropy",
                           reduce_fn=tf.reduce_mean):
        """Sets up the cross entropy loss, which is equivalent to -log(p(Y|X)).

        Args:
            name (str): Name of the name scope for the Ops defined here
            reduce_fn (function): A function that reduces the losses for all samples to a scalar.

        Returns:
            A Tensor corresponding to the cross-entropy loss.
        """
        with tf.name_scope(name):
            log_prob_data_and_labels = LogValue().get_value(self._root)
            log_prob_data = self._log_likelihood()
            return -reduce_fn(log_prob_data_and_labels - log_prob_data)

    def negative_log_likelihood(self,
                                name="NegativeLogLikelihood",
                                reduce_fn=tf.reduce_mean):
        """Returns the maximum (log) likelihood estimate loss function which corresponds to
        -log(p(X)) in the case of unsupervised learning or -log(p(X,Y)) in the case of supervised
        learning.

        Args:
            name (str): The name for the name scope to use
            reduce_fn (function): A function that returns an operation that reduces the losses
                for all samples to a scalar.
        Returns:
            A Tensor corresponding to the MLE loss
        """
        with tf.name_scope(name):
            if self._learning_task_type == LearningTaskType.UNSUPERVISED:
                if self._root.latent_indicators is not None:
                    likelihood = self._log_likelihood()
                else:
                    likelihood = self._log_value.get_value(self._root)
            elif self._root.latent_indicators is None:
                raise StructureError(
                    "Root should have latent indicator node when doing supervised "
                    "learning.")
            else:
                likelihood = self._log_value.get_value(self._root)
            return -reduce_fn(likelihood)

    def _log_likelihood(self):
        """Computes log(p(X)) by creating a copy of the root node without latent indicators.

        Returns:
            A Tensor of shape [batch, 1] corresponding to the log likelihood of the data.
        """
        if isinstance(self._root, BaseSum):
            marginalizing_root = self._marginalizing_root or Sum(
                *self._root.values, weights=self._root.weights)
        else:
            marginalizing_root = self._marginalizing_root or BlockSum(
                self._root.values[0],
                weights=self._root.weights,
                num_sums_per_block=1)
        return self._log_value.get_value(marginalizing_root)
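
A rough end-to-end usage sketch of GDLearning, assuming an SPN graph rooted at ``root`` with a
latent IndicatorLeaf attached; ``root``, ``feed`` and ``num_steps`` are placeholders built
elsewhere, not part of the library:

learning = GDLearning(root,
                      learning_task_type=LearningTaskType.SUPERVISED,
                      learning_method=LearningMethodType.DISCRIMINATIVE)
update_op, loss_op = learning.learn(optimizer=tf.train.AdamOptimizer(1e-3))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):           # num_steps is a placeholder
        _, loss = sess.run([update_op, loss_op], feed_dict=feed)  # feed is a placeholder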
Beispiel #17
0
class Dataset(ABC):
    """An abstract class defining the interface of a dataset.

    Args:
        num_vars (int): Number of variables in each data sample.
        num_vals (int or list of int): Number of values of each variable. Can be
            a single value or a list of values, one for each of ``num_vars``
            variables. Use ``None`` to indicate that a variable is continuous,
            in the range ``[0, 1]``.
        num_labels (int): Number of labels for each data sample.
        num_epochs (int): Number of epochs of produced data.
        batch_size (int): Size of a single batch.
        shuffle (bool): Shuffle data within each epoch.
        shuffle_batch (bool): Shuffle data when generating batches.
        min_after_dequeue (int): Min number of elements in the data queue after
                                 each dequeue. This is the minimum number of
                                 elements from which the shuffled batch will
                                 be drawn. Only relevant, and must be set, if
                                 ``shuffle_batch`` is ``True``.
        num_threads (int): Number of threads enqueuing the data queue. If
                           larger than ``1``, the performance will be better,
                           but examples might not be in order even if
                           ``shuffle_batch`` is ``False``. If ``shuffle_batch``
                           is ``True``, this might lead to examples repeating in
                           the same batch.
        allow_smaller_final_batch(bool): If ``False``, the last batch will be
                                         omitted if it has fewer elements than
                                         ``batch_size``.
        seed (int): Optional. Seed used when shuffling.
    """

    __logger = get_logger()
    __info = __logger.info

    def __init__(self,
                 num_vars,
                 num_vals,
                 num_labels,
                 num_epochs,
                 batch_size,
                 shuffle,
                 shuffle_batch,
                 min_after_dequeue=None,
                 num_threads=1,
                 allow_smaller_final_batch=False,
                 seed=None):
        if not isinstance(num_vars, int) or num_vars < 1:
            raise ValueError("num_vars must be a positive integer")
        self._num_vars = num_vars
        if isinstance(num_vals, list):
            if len(num_vals) != num_vars:
                raise ValueError("num_vals must have num_vars elements")
            if any((i is not None) and (not isinstance(i, int) or i < 1)
                   for i in num_vals):
                raise ValueError(
                    "num_vals values must be a positive integers or None")
            # If all elements are the same, just convert to int
            if num_vals.count(num_vals[0]) == len(num_vals):
                self._num_vals = num_vals[0]
            else:
                self._num_vals = num_vals
        else:
            if ((num_vals is not None)
                    and (not isinstance(num_vals, int) or num_vals < 1)):
                raise ValueError("num_vals must be a positive integer or None")
            self._num_vals = num_vals
        if not isinstance(num_labels, int) or num_labels < 0:
            raise ValueError("num_labels must be an integer >= 0")
        self._num_labels = num_labels
        if not isinstance(num_epochs, int) or num_epochs < 1:
            raise ValueError("num_epochs must be a positive integer")
        self._num_epochs = num_epochs
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        self._batch_size = batch_size
        if shuffle_batch and not shuffle:
            raise RuntimeError("Batch shuffling should not be enabled "
                               "when shuffle is False.")
        if shuffle_batch and min_after_dequeue is None:
            raise RuntimeError("min_after_dequeue must be set if batch "
                               "shuffling is enabled.")
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_after_dequeue = min_after_dequeue
        self._num_threads = num_threads
        self._allow_smaller_final_batch = allow_smaller_final_batch
        if seed is not None and (not isinstance(seed, int) or seed < 1):
            raise ValueError("seed must be None or a positive integer")
        self._seed = seed
        self._name_scope = None

    @property
    def num_vars(self):
        """int: Number of variables in each data sample."""
        return self._num_vars

    @property
    def num_vals(self):
        """int or list of int: Number of values of each variable.

        If each variable has the same number of values, the value is returned as
        a single integer or ``None``. If the values differ, a list of length
        ``num_vars`` is returned, where each value represents the number of
        values for a single variable.

        Value can be either an integer or ``None`` indicating that a variable is
        continuous, in the range ``[0, 1]``.
        """
        return self._num_vals

    @property
    def num_labels(self):
        """int: Number of labels for each data sample."""
        return self._num_labels

    @property
    def num_epochs(self):
        """int: Number of epochs of produced data."""
        return self._num_epochs

    @property
    def batch_size(self):
        """int: Size of a single batch."""
        return self._batch_size

    @property
    def shuffle(self):
        """bool: ``True`` if provided data is shuffled."""
        return self._shuffle

    @property
    def seed(self):
        """int: Seed used when shuffling."""
        return self._seed

    @property
    def allow_smaller_final_batch(self):
        """bool: If ``False``, the last batch is omitted if it has less
        elements than ``batch_size``."""
        return self._allow_smaller_final_batch

    def get_data(self):
        """Get an operation obtaining batches of data from the dataset.

        Returns:
            A tensor or a list of tensors with the batch data.
        """
        self.__info("Building dataset operations")
        with tf.name_scope("Dataset") as self._name_scope:
            raw_data = self.generate_data()
            proc_data = self.process_data(raw_data)
            return self.batch_data(proc_data)

    @abstractmethod
    def generate_data(self):
        """Assemble a TF operation generating the next data sample.

        Returns:
            A list of tensors with a single data sample.
        """
        pass

    @abstractmethod
    def process_data(self, data):
        """Assemble a TF operation processing a data sample.

        Args:
            data: A list of tensors with a single data sample.

        Returns:
            A list of tensors with a single data sample.
        """
        pass

    def batch_data(self, data):
        """Assemble a TF operation producing batches of data samples.

        Args:
            data: A list of tensors or a dictionary of tensors with
                  a single data sample. If the list of tensors contains
                  only one element, this function returns a tensor.
                  Otherwise, it returns a list or a dictionary of tensors.

        Returns:
            A tensor, a list of tensors or a dictionary of tensors with a
            batch of data.
        """
        if self._shuffle_batch:
            # If len(data) is 1, batch will be a tensor
            # If len(data) > 1, batch will be a list of tensors
            batch = tf.train.shuffle_batch(
                data,
                batch_size=self._batch_size,
                num_threads=self._num_threads,
                seed=self._seed,
                capacity=(self._min_after_dequeue +
                          (self._num_threads + 1) * self._batch_size),
                min_after_dequeue=self._min_after_dequeue,
                allow_smaller_final_batch=self._allow_smaller_final_batch)
        else:
            # If len(data) is 1, batch will be a tensor
            # If len(data) > 1, batch will be a list of tensors
            batch = tf.train.batch(
                data,
                batch_size=self._batch_size,
                num_threads=self._num_threads,
                capacity=(self._num_threads + 1) * self._batch_size,
                allow_smaller_final_batch=self._allow_smaller_final_batch)
        return batch

    def read_all(self):
        """Read all data (all batches and epochs) from the dataset into numpy
        arrays.

        Returns:
            An array, a list of arrays or a dictionary of arrays with all the
            data in the dataset.
        """
        # Read all batches in internal graph
        batches = []
        with tf.Graph().as_default():
            data = self.get_data()
            with session() as (sess, run):
                while run():
                    out = sess.run(data)
                    batches.append(out)
        # Concatenate
        if isinstance(batches[0], list):
            return [
                np.concatenate([b[key] for b in batches])
                for key in range(len(batches[0]))
            ]
        else:
            return np.concatenate(batches)

    def write_all(self, writer):
        """Write all data (all batches and epochs) from the dataset using the
        given writer. Each batch is written using a separate ``write()`` call on
        the writer. Therefore, even datasets that do not fit in memory can be
        written this way.

        Args:
            writer (DataWriter): The data writer to use.
        """
        self.__info("Writing all data from %s to %s" %
                    (type(self).__name__, type(writer).__name__))
        with tf.Graph().as_default():
            data = self.get_data()
            with session() as (sess, run):
                i = 0
                while run():
                    out = sess.run(data)
                    i += 1
                    self.__info("Writing batch %d" % i)
                    if not isinstance(out, list):  # Convert to list
                        out = [out]
                    writer.write(*out)
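
To illustrate how the two abstract hooks interact with ``batch_data``, here is a minimal,
hypothetical in-memory subclass; the class name, the use of ``tf.train.slice_input_producer``
and the default parameters are assumptions made for this sketch, not part of the library:

class InMemoryDataset(Dataset):
    """Serves rows of a float array as continuous data samples."""

    def __init__(self, samples, batch_size=32, num_epochs=1):
        samples = np.asarray(samples, dtype=np.float32)
        super().__init__(num_vars=samples.shape[1], num_vals=None,
                         num_labels=0, num_epochs=num_epochs,
                         batch_size=batch_size, shuffle=False,
                         shuffle_batch=False)
        self._samples = samples

    def generate_data(self):
        # One row per dequeue; the session helper is assumed to start the queue runners
        return tf.train.slice_input_producer([tf.constant(self._samples)],
                                             num_epochs=self.num_epochs,
                                             shuffle=self.shuffle,
                                             seed=self.seed)

    def process_data(self, data):
        return data  # no per-sample processing in this sketch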
Beispiel #18
0
class PermuteProducts(OpNode):
    """A node representing multiple products, permuted over the input space, in
       an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values. The only
            criterion is that all inputs in the list should have the same
            dimension.
        name (str): Name of the node.
    """

    logger = get_logger()
    info = logger.info

    def __init__(self, *values, name="PermuteProducts"):
        self._values = []
        super().__init__(inference_type=InferenceType.MARGINAL, name=name)
        self.set_values(*values)

        self.create_products()

    def serialize(self):
        data = super().serialize()
        data['input_sizes'] = self._input_sizes
        data['num_inputs'] = self._num_inputs
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        return data

    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()

    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['values'])
        self.create_products(input_sizes=data['input_sizes'],
                             num_inputs=data['num_inputs'])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return self._values

    @property
    def num_prods(self):
        """int: Number of Product ops modelled by this node."""
        return self._num_prods

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def create_products(self, input_sizes=None, num_inputs=None):
        """Based on the number and size of inputs connected to this node, model
        products by permuting over the inputs.
        """
        if not self._values:
            raise StructureError("%s is missing input values." % self)

        self._input_sizes = input_sizes if input_sizes is not None \
            else list(self.get_input_sizes())
        self._num_inputs = num_inputs if num_inputs is not None \
            else len(self._input_sizes)

        # Calculate number of products this node would model.
        if self._num_inputs == 1:
            self._num_prods = 1
        else:
            self._num_prods = int(np.prod(self._input_sizes))

        # Create indices by permuting over the input space, such that inputs
        # for the products can be generated by gathering from concatenated
        # input values.
        self._permuted_indices = self.permute_indices(self._input_sizes)

    def permute_indices(self, input_sizes):
        """Create indices by permuting over the inputs, such that inputs for each
        product modeled by this node can be generated by gathering from concatenated
        values of the node.

        Args:
            input_sizes (list): List of input sizes.

        Return:
            permuted_indices (list): List of indices for gathering inputs of all
                                     the product nodes modeled by this Op.
        """
        ind_range = np.cumsum([0] + input_sizes)
        ind_list = list(
            product(*[
                range(start, stop)
                for start, stop in zip(ind_range, ind_range[1:])
            ]))

        return list(chain(*ind_list))

    def add_values(self, *values):
        """Add more inputs providing input values to this node. Then remodel the
        products based on the newly added inputs.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._values + self._parse_inputs(*values)
        self.create_products()

    @property
    def _const_out_size(self):
        return True

    @utils.lru_cache
    def _compute_out_size(self, *input_out_sizes):
        return self._num_prods

    def _compute_scope(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        value_scopes = self._gather_input_scopes(*value_scopes)
        if self._num_prods == 1:
            return [Scope.merge_scopes(chain.from_iterable(value_scopes))]

        value_scopes_list = [
            Scope.merge_scopes(pvs)
            for pvs in product(*[vs for vs in value_scopes])
        ]

        return value_scopes_list

    def _compute_valid(self, *value_scopes):
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        value_scopes_ = self._gather_input_scopes(*value_scopes)
        # If already invalid, return None
        if any(s is None for s in value_scopes_):
            return None
        if self._num_prods == 1:
            for s1, s2 in combinations(chain(*value_scopes_), 2):
                if s1 & s2:
                    PermuteProducts.info(
                        "%s is not decomposable with input value "
                        "scopes %s", self, value_scopes_[:10])
                    return None

        # Check product decomposability
        for perm_val_scope in product(*value_scopes_):
            for s1, s2 in combinations(perm_val_scope, 2):
                if s1 & s2:
                    PermuteProducts.info("%s is not decomposable", self)
                    return None
        return self._compute_scope(*value_scopes)

    @utils.lru_cache
    def _compute_value_common(self, *value_tensors):
        """Common actions when computing value."""
        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)
        # Prepare values
        value_tensors = self._gather_input_tensors(*value_tensors)
        if len(value_tensors) > 1:
            values = tf.concat(values=value_tensors, axis=1)
        else:
            values = value_tensors[0]
        if self._num_prods > 1:
            # Gather values based on permuted_indices
            permuted_values = tf.gather(values, self._permuted_indices, axis=1)

            # Shape of values tensor = [Batch, (num_prods * num_vals)]
            # First, split the values tensor into 'num_prods' smaller tensors.
            # Then pack the split tensors together such that the new shape
            # of values tensor = [Batch, num_prods, num_vals]
            reshape = (-1, self._num_prods,
                       int(permuted_values.shape[1].value / self._num_prods))
            reshaped_values = tf.reshape(permuted_values, shape=reshape)
            return reshaped_values
        else:
            return values

    @utils.lru_cache
    def _compute_log_value(self, *value_tensors):
        values = self._compute_value_common(*value_tensors)

        # Wrap the log value with its custom gradient
        @tf.custom_gradient
        def log_value(*value_tensors):
            # Defines gradient for the log value
            def gradient(gradients):
                scattered_grads = self._compute_log_mpe_path(
                    gradients, *value_tensors)
                return [sg for sg in scattered_grads if sg is not None]

            return tf.reduce_sum(
                values,
                axis=-1,
                keepdims=(False if self._num_prods > 1 else True)), gradient

        if conf.custom_gradient:
            return log_value(*value_tensors)
        else:
            return tf.reduce_sum(
                values,
                axis=-1,
                keepdims=(False if self._num_prods > 1 else True))

    def _compute_log_mpe_value(self, *value_tensors):
        return self._compute_log_value(*value_tensors)

    @utils.lru_cache
    def _compute_log_mpe_path(self,
                              counts,
                              *value_values,
                              use_unweighted=False,
                              sample=False,
                              sample_prob=None):
        # Path per product node is calculated by permuting backwards to the
        # input nodes, then adding the appropriate counts per input, and then
        # scattering the summed counts to value inputs

        # Check inputs
        if not self._values:
            raise StructureError("%s is missing input values." % self)

        def permute_counts(input_sizes):
            # Function that permutes count values, backward to inputs.
            counts_indices_list = []

            def range_with_blocksize(start, stop, block_size, step):
                # A function that produces an arithmetic progression (Similar to
                # Python's range() function), but for a given block-size of
                # consecutive numbers.
                # E.g: range_with_blocksize(start=0, stop=20, block_size=3, step=5)
                # = [0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17]
                counts_indices = []
                it = 0
                low = start
                high = low + block_size
                while low < stop:
                    counts_indices = counts_indices + list(range(low, high))
                    it += 1
                    low = start + (it * step)
                    high = low + block_size

                return counts_indices

            for inp, inp_size in enumerate(input_sizes):
                block_size = int(self._num_prods /
                                 np.prod(input_sizes[:inp + 1]))
                step = int(np.prod(input_sizes[inp:]))
                for i in range(inp_size):
                    start = i * block_size
                    stop = self._num_prods - (block_size * (inp_size - i - 1))
                    counts_indices_list.append(
                        range_with_blocksize(start, stop, block_size, step))

            return counts_indices_list

        if (len(self._input_sizes) > 1):
            permuted_indices = permute_counts(self._input_sizes)
            summed_counts = tf.reduce_sum(utils.gather_cols_3d(
                counts, permuted_indices),
                                          axis=-1)
            processed_counts_list = tf.split(summed_counts,
                                             self._input_sizes,
                                             axis=-1)
        else:  # For single input case, i.e., when _num_prods = 1
            summed_counts = self._input_sizes[0] * [counts]
            processed_counts_list = [tf.concat(values=summed_counts, axis=-1)]

        # Zip lists of processed counts and value_values together for scattering
        value_counts = zip(processed_counts_list, value_values)

        return self._scatter_to_input_tensors(*value_counts)

    def _compute_log_gradient(self, gradients, *value_values):
        return self._compute_log_mpe_path(gradients, *value_values)

    def disconnect_inputs(self):
        self._values = None
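
To make the permutation concrete, the index construction in ``permute_indices`` can be
reproduced standalone; for two inputs of sizes 2 and 3 it yields 6 products of 2 indices
each (this snippet only mirrors the method above and is not part of the class):

import numpy as np
from itertools import chain, product

def permute_indices(input_sizes):
    # Offsets of each input inside the concatenated value tensor
    ind_range = np.cumsum([0] + list(input_sizes))
    ind_list = list(product(*[range(start, stop)
                              for start, stop in zip(ind_range, ind_range[1:])]))
    return list(chain(*ind_list))

print(permute_indices([2, 3]))
# [0, 2, 0, 3, 0, 4, 1, 2, 1, 3, 1, 4]  ->  6 products, 2 indices per product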
Beispiel #19
0
class BaseSum(OpNode, abc.ABC):
    """An abstract node representing sums in an SPN.

    Args:
        *values (input_like): Inputs providing input values to this node.
            See :meth:`~libspn.Input.as_input` for possible values.
        num_sums (int): Number of Sum ops modelled by this node.
        sum_sizes (list): A list of ints corresponding to the sizes of each sum. If both num_sums
                          and sum_sizes are given, we should have len(sum_sizes) == num_sums.
        batch_axis (int): The index of the batch axis.
        op_axis (int): The index of the op axis that contains the individual sums being modeled.
        reduce_axis (int): The axis over which to perform summing (or max for MPE)
        weights (input_like): Input providing weights node to this sum node.
            See :meth:`~libspn.Input.as_input` for possible values. If set
            to ``None``, the input is disconnected.
        latent_indicators (input_like): Input providing IndicatorLeafs of an explicit latent variable
            associated with this sum node. See :meth:`~libspn.Input.as_input`
            for possible values. If set to ``None``, the input is disconnected.
        name (str): Name of the node.

    Attributes:
        inference_type(InferenceType): Flag indicating the preferred inference
                                       type for this node that will be used
                                       during value calculation and learning.
                                       Can be changed at any time and will be
                                       used during the next inference/learning
                                       op generation.
    """
    def __init__(self,
                 *values,
                 num_sums,
                 weights=None,
                 latent_indicators=None,
                 sum_sizes=None,
                 inference_type=InferenceType.MARGINAL,
                 batch_axis=0,
                 op_axis=1,
                 reduce_axis=2,
                 masked=False,
                 sample_prob=None,
                 name="Sum"):
        super().__init__(inference_type=inference_type, name=name)

        self.set_values(*values)
        self.set_weights(weights)
        self.set_latent_indicators(latent_indicators)

        # Initialize the number of sums and the sum sizes
        self._reset_sum_sizes(num_sums=num_sums, sum_sizes=sum_sizes)

        # Set the axes
        self._batch_axis = batch_axis
        self._op_axis = op_axis
        self._reduce_axis = reduce_axis

        # Set whether this instance is masked (e.g. SumsLayer)
        self._masked = masked

        # Set the sampling probability and sampling type
        self._sample_prob = sample_prob

    def _get_sum_sizes(self, num_sums):
        """Computes a list of sum sizes given the number of sums and the currently attached input
        nodes.

        Args:
            num_sums (int): The number of sums modeled by this node.
        Returns:
            A list of sum sizes, where the i-th element corresponds to the size of the i-th sum.
        """
        input_sizes = self.get_input_sizes()
        num_values = sum(input_sizes[2:])  # Skip latent_indicators, weights
        return num_sums * [num_values]

    def _build_mask(self):
        """Constructs mask that could be used to cancel out 'columns' that are padded as a result of
        varying reduction sizes. Returns a Boolean mask.

        Returns:
            By default the sums are not masked, returns ``None``
        """
        return None

    @utils.docinherit(OpNode)
    def serialize(self):
        data = super().serialize()
        data['values'] = [(i.node.name, i.indices) for i in self._values]
        if self._weights:
            data['weights'] = (self._weights.node.name, self._weights.indices)
        if self._latent_indicators:
            data['latent_indicators'] = (self._latent_indicators.node.name,
                                         self._latent_indicators.indices)
        data['num_sums'] = self._num_sums
        data['sum_sizes'] = self._sum_sizes
        data['op_axis'] = self._op_axis
        data['reduce_axis'] = self._reduce_axis
        data['batch_axis'] = self._batch_axis
        data['sample_prob'] = self._sample_prob
        return data

    @utils.docinherit(OpNode)
    def deserialize(self, data):
        super().deserialize(data)
        self.set_values()
        self.set_weights()
        self.set_latent_indicators()
        self._sample_prob = data['sample_prob']
        self._num_sums = data['num_sums']
        self._sum_sizes = data['sum_sizes']
        self._max_sum_size = max(self._sum_sizes) if self._sum_sizes else 0
        self._batch_axis = data['batch_axis']
        self._op_axis = data['op_axis']
        self._reduce_axis = data['reduce_axis']

    def disconnect_inputs(self):
        self._latent_indicators = self._weights = self._values = None

    @utils.docinherit(OpNode)
    def deserialize_inputs(self, data, nodes_by_name):
        super().deserialize_inputs(data, nodes_by_name)
        self._values = tuple(
            Input(nodes_by_name[nn], i) for nn, i in data['values'])
        weights = data.get('weights', None)
        if weights:
            self._weights = Input(nodes_by_name[weights[0]], weights[1])
        latent_indicators = data.get('latent_indicators', None)
        if latent_indicators:
            self._latent_indicators = Input(
                nodes_by_name[latent_indicators[0]], latent_indicators[1])

    @property
    @utils.docinherit(OpNode)
    def inputs(self):
        return (self._weights, self._latent_indicators) + self._values

    @property
    def weights(self):
        """Input: Weights input."""
        return self._weights

    def set_weights(self, weights=None):
        """Set the weights input.

        Args:
            weights (input_like): Input providing weights node to this sum node.
                See :meth:`~libspn.Input.as_input` for possible values. If set
                to ``None``, the input is disconnected.
        """
        weights, = self._parse_inputs(weights)
        if weights and not isinstance(weights.node, Weights):
            raise StructureError("%s is not Weights" % weights.node)
        self._weights = weights

    def _reset_sum_sizes(self, num_sums=None, sum_sizes=None):
        """Resets the sizes and number of sums. If number of sums is specified, it will take that
        value, otherwise it will take the value that is already set. If sum_sizes is specified
        it will take that value, otherwise it will infer that using
        :meth:`~libspn.BaseSum._get_sum_sizes`. Finally, it also sets the maximum sum size.

        Args:
            num_sums (int): Number of sums modeled by this ``Node``.
            sum_sizes (int): A list of sum sizes with as many ``int``s as there are sums modeled.
        """
        self._num_sums = num_sums or self._num_sums
        self._sum_sizes = sum_sizes or self._get_sum_sizes(self._num_sums)
        self._max_sum_size = max(self._sum_sizes) if self._sum_sizes else 0

    @property
    def latent_indicators(self):
        """Input: IndicatorLeafs input."""
        return self._latent_indicators

    def set_latent_indicators(self, latent_indicators=None):
        """Set the IndicatorLeafs input.

        Args:
            latent_indicators (input_like): Input providing IndicatorLeaf of an explicit latent
                variable associated with this sum node. See :meth:`~libspn.Input.as_input`
                for possible values. If set to ``None``, the input is disconnected.
        """
        self._latent_indicators, = self._parse_inputs(latent_indicators)

    @property
    def values(self):
        """list of Input: List of value inputs."""
        return self._values

    @property
    def num_sums(self):
        """int: The number of sums modeled by this node. """
        return self._num_sums

    @property
    def sum_sizes(self):
        """list of int: A list of the sum sizes. """
        return self._sum_sizes

    def set_values(self, *values):
        """Set the inputs providing input values to this node. If no arguments
        are given, all existing value inputs get disconnected.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values = self._parse_inputs(*values)

    def add_values(self, *values):
        """Add more inputs providing input values to this node.

        Args:
            *values (input_like): Inputs providing input values to this node.
                See :meth:`~libspn.Input.as_input` for possible values.
        """
        self._values += self._parse_inputs(*values)
        self._reset_sum_sizes()

    def generate_weights(self,
                         initializer=tf.initializers.constant(1.0),
                         trainable=True,
                         input_sizes=None,
                         log=False,
                         name=None):
        """Generate a weights node matching this sum node and connect it to
        this sum.

        The function calculates the number of weights based on the number
        of input values of this sum. Therefore, weights should be generated
        once all inputs are added to this node.

        Args:
            initializer: Initial value of the weights.
            trainable (bool): See :class:`~libspn.Weights`.
            input_sizes (list of int): Pre-computed sizes of each input of
                this node.  If given, this function will not traverse the graph
                to discover the sizes.
            log (bool): If "True", the weights are represented in log space.
            name (str): Name of the weighs node. If ``None`` use the name of the
                        sum + ``_Weights``.

        Return:
            Weights: Generated weights node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_Weights"
        # Count all input values
        if input_sizes:
            num_values = sum(
                input_sizes[2:])  # Skip latent_indicators, weights
        else:
            num_values = max(self._sum_sizes)
        # Generate weights
        weights = Weights(initializer=initializer,
                          num_weights=num_values,
                          num_sums=self._num_sums,
                          log=log,
                          trainable=trainable,
                          name=name)
        self.set_weights(weights)
        return weights

    def generate_latent_indicators(self, feed=None, name=None):
        """Generate an IndicatorLeaf node matching this sum node and connect it to
        this sum.

        IndicatorLeafs should be generated once all inputs are added to this node,
        otherwise the number of IndicatorLeafs will be incorrect.

        Args:
            feed (Tensor): See :class:`~libspn.IndicatorLeaf`.
            name (str): Name of the IndicatorLeaf node. If ``None`` use the name of the
                        sum + ``_IndicatorLeaf``.

        Return:
            IndicatorLeaf: Generated IndicatorLeaf node.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if name is None:
            name = self._name + "_IndicatorLeaf"
        latent_indicators = IndicatorLeaf(feed=feed,
                                          num_vars=self._num_sums,
                                          num_vals=self._max_sum_size,
                                          name=name)
        self.set_latent_indicators(latent_indicators)
        return latent_indicators

    @utils.lru_cache
    def _compute_reducible(self,
                           w_tensor,
                           latent_indicators_tensor,
                           *input_tensors,
                           weighted=True):
        """Computes a reducible ``Tensor`` so that reducing it over the last axis can be used for
        marginal inference, MPE inference and MPE path computation.

        Args:
            w_tensor (Tensor): A ``Tensor`` with the value of the weights of shape
                ``[num_sums, max_sum_size]``
            latent_indicators_tensor (Tensor): A ``Tensor`` with the value of the IndicatorLeaf corresponding to this node
                of shape ``[batch, num_sums * max_sum_size]``.
            input_tensors (tuple): A ``tuple`` of ``Tensor``s with the values of the children of
                this node.
            weighted (bool): Whether to apply the weights to the reducible values if possible.

        Returns:
            A ``Tensor`` of shape ``[batch, num_sums, max_sum_size]`` that can be used for computing
            marginal inference, MPE inference, gradients or MPE paths.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        if not self._weights:
            raise StructureError("%s is missing weights" % self)

        # Prepare tensors for component-wise application of weights and IndicatorLeaf
        w_tensor, latent_indicators_tensor, reducible = self._prepare_component_wise_processing(
            w_tensor,
            latent_indicators_tensor,
            *input_tensors,
            zero_prob_val=-float('inf'))

        # Apply latent IndicatorLeaf
        if self._latent_indicators:
            reducible = utils.cwise_add(reducible, latent_indicators_tensor)

        # Apply weights
        if weighted:
            reducible = utils.cwise_add(reducible, w_tensor)

        return reducible

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_out_size(self, *input_out_sizes):
        return self._num_sums

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_value(self, w_tensor, latent_indicators_tensor,
                           *value_tensors):
        return self._reduce_marginal_inference_log(
            self._compute_reducible(w_tensor,
                                    latent_indicators_tensor,
                                    *value_tensors,
                                    weighted=True))

    def _get_differentiable_inputs(self, w_tensor, latent_indicators_tensor,
                                   *value_tensors):
        """Selects the tensors to include for a tf.custom_gradient when computing the log-value.

        Args:
            w_tensor (Tensor): A ``Tensor`` of shape [num_sums, max_sum_size] with the value of
                               the weights corresponding to this node.
            latent_indicators_tensor (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] with the
                                 value of the IndicatorLeaf corresponding to this node.
            value_tensors (tuple): A ``tuple`` of ``Tensor``s holding the outputs of the
                                   children of this node.

        Returns:
            A ``list`` of the tensors to register as differentiable inputs.
        """
        return [w_tensor
                ] + ([latent_indicators_tensor]
                     if self._latent_indicators else []) + list(value_tensors)

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_mpe_value(self, w_tensor, latent_indicators_tensor,
                               *value_tensors):
        return self._reduce_mpe_inference_log(
            self._compute_reducible(w_tensor,
                                    latent_indicators_tensor,
                                    *value_tensors,
                                    weighted=True))

    @utils.lru_cache
    def _compute_mpe_path_common(self,
                                 reducible_tensor,
                                 counts,
                                 w_tensor,
                                 latent_indicators_tensor,
                                 *input_tensors,
                                 sample=False,
                                 sample_prob=None,
                                 accumulate_weights_batch=False):
        """Common operations for computing the MPE path.

        Args:
            reducible_tensor (Tensor): A (weighted) ``Tensor`` of (log-)values of this node.
            counts (Tensor): A ``Tensor`` that contains the accumulated counts of the parents
                             of this node.
            w_tensor (Tensor):  A ``Tensor`` containing the (log-)value of the weights.
            latent_indicators_tensor (Tensor): A ``Tensor`` containing the (log-)value of the IndicatorLeaf.
            input_tensors (list): A list of ``Tensor``s with outputs of the child nodes.
            sample (bool): Whether to sample the 'winner' of the max or not.
            sample_prob (Tensor): A scalar ``Tensor`` indicating the probability of drawing
                a sample. If a sample is drawn, the probability for each index is given by the
                (log-)normalized probability as given by ``reducible_tensor``.
            accumulate_weights_batch (bool): If ``True``, the weight counts are summed over
                the batch axis before being returned.
        Returns:
            A ``list`` of ``tuple``s [(MPE counts, input tensor), ...] where the first corresponds
            to the Weights of this node, the second corresponds to the IndicatorLeaf and the remaining
            tuples correspond to the nodes in ``self._values``.
        """
        sample_prob = utils.maybe_first(sample_prob, self._sample_prob)
        num_samples = 1 if reducible_tensor.shape[
            self._reduce_axis] != 1 else self._num_sums
        if sample:
            max_indices = self._reduce_sample_log(reducible_tensor,
                                                  sample_prob=sample_prob,
                                                  num_samples=num_samples)
        else:
            max_indices = self._reduce_argmax(reducible_tensor,
                                              num_samples=num_samples)
        max_counts = utils.scatter_values(params=counts,
                                          indices=max_indices,
                                          num_out_cols=self._max_sum_size)
        max_counts_acc, max_counts_split = self._accumulate_and_split_to_children(
            max_counts, *input_tensors)
        if accumulate_weights_batch:
            max_counts = tf.reduce_sum(max_counts, axis=0, keepdims=False)
        return self._scatter_to_input_tensors(
            (max_counts, w_tensor),  # Weights
            (max_counts_acc, latent_indicators_tensor),  # IndicatorLeaf
            *[(t, v)
              for t, v in zip(max_counts_split, input_tensors)])  # Values

    @utils.lru_cache
    def _accumulate_and_split_to_children(self, x, *input_tensors):
        """Accumulates the values in x over the op axis. Potentially also accumulates for every
        unique input if appropriate (e.g. in SumsLayer).

        Args:
            x (Tensor): A ``Tensor`` containing the values to accumulate and split among the
                        children.
            input_tensors (tuple): A ``tuple`` of ``Tensors`` holding the value of the children's
                                   outputs. These might be used in e.g. SumsLayer to determine
                                   unique inputs so that values can be accumulated before passing
                                   them downward.
        Returns:
            A ``tuple`` of size 2 with the accumulated values and a list of accumulated values
            corresponding to each input.
        """
        if self._num_sums > 1:
            x_acc = tf.reduce_sum(x, axis=self._op_axis)
        else:
            x_acc = tf.squeeze(x, axis=self._op_axis)

        _, _, *value_sizes = self.get_input_sizes()
        return x_acc, tf.split(x_acc, value_sizes, axis=self._batch_axis + 1)

    @utils.docinherit(OpNode)
    @utils.lru_cache
    def _compute_log_mpe_path(self,
                              counts,
                              w_tensor,
                              latent_indicators_tensor,
                              *input_tensors,
                              use_unweighted=False,
                              accumulate_weights_batch=False,
                              sample=False,
                              sample_prob=None):
        weighted = not use_unweighted or any(v.node.is_var
                                             for v in self._values)
        reducible = self._compute_reducible(w_tensor,
                                            latent_indicators_tensor,
                                            *input_tensors,
                                            weighted=weighted)

        return self._compute_mpe_path_common(
            reducible,
            counts,
            w_tensor,
            latent_indicators_tensor,
            *input_tensors,
            accumulate_weights_batch=accumulate_weights_batch,
            sample=sample,
            sample_prob=sample_prob)

    @property
    def _tile_unweighted_size(self):
        return self._num_sums

    def _get_flat_value_scopes(self, weight_scopes, latent_indicators_scopes,
                               *value_scopes):
        """Get a flat representation of the value scopes per sum.

        Args:
            weight_scopes (list): A list of ``Scope``s corresponding to the weights.
            latent_indicators_scopes (list): A list of ``Scope``s corresponding to the IndicatorLeaf.
            value_scopes (tuple): A ``tuple`` of ``list``s of ``Scope``s corresponding to the
                                  scope lists of the children of this node.

        Returns:
            A tuple containing the flattened value scopes of this node's inputs, the
            IndicatorLeaf scopes and the per-input value scopes.
        """
        if not self._values:
            raise StructureError("%s is missing input values" % self)
        _, latent_indicators_scopes, *value_scopes = self._gather_input_scopes(
            weight_scopes, latent_indicators_scopes, *value_scopes)
        return list(chain.from_iterable(
            value_scopes)), latent_indicators_scopes, value_scopes

    @utils.docinherit(OpNode)
    def _compute_scope(self, weight_scopes, latent_indicators_scopes,
                       *value_scopes):
        flat_value_scopes, latent_indicators_scopes, *value_scopes = self._get_flat_value_scopes(
            weight_scopes, latent_indicators_scopes, *value_scopes)
        if self._latent_indicators:
            sublist_size = int(len(latent_indicators_scopes) / self._num_sums)
            # Divide gathered latent_indicators scopes into sublists, one per modelled Sum node.
            latent_indicators_scopes_sublists = [
                latent_indicators_scopes[i:i + sublist_size]
                for i in range(0, len(latent_indicators_scopes), sublist_size)
            ]
        return [
            Scope.merge_scopes(
                flat_value_scopes + latent_indicators_scopes_sublists[i]
                if self._latent_indicators else flat_value_scopes)
            for i in range(self._num_sums)
        ]

    @utils.docinherit(OpNode)
    def _compute_valid(self, weight_scopes, latent_indicators_scopes,
                       *value_scopes):
        # If already invalid, return None
        if (any(s is None for s in value_scopes) or
            (self._latent_indicators and latent_indicators_scopes is None)):
            return None
        flat_value_scopes, latent_indicators_scopes_, *value_scopes_ = self._get_flat_value_scopes(
            weight_scopes, latent_indicators_scopes, *value_scopes)
        # IndicatorLeaf
        if self._latent_indicators:
            # Verify number of IndicatorLeaf
            if len(latent_indicators_scopes_
                   ) != len(flat_value_scopes) * self._num_sums:
                raise StructureError(
                    "Number of IndicatorLeaf (%s) and values (%s) does "
                    "not match for %s" %
                    (len(latent_indicators_scopes_),
                     len(flat_value_scopes) * self._num_sums, self))
            # Check if scope of all IndicatorLeaf is just one and the same variable
            if len(Scope.merge_scopes(
                    latent_indicators_scopes_)) > self._num_sums:
                return None
        # Check sum for completeness wrt values
        first_scope = flat_value_scopes[0]
        if any(s != first_scope for s in flat_value_scopes[1:]):
            self.info("%s is not complete with input value scopes %s", self,
                      flat_value_scopes)
            return None

        return self._compute_scope(weight_scopes, latent_indicators_scopes,
                                   *value_scopes)

    @property
    @utils.docinherit(OpNode)
    def _const_out_size(self):
        return True

    @utils.lru_cache
    def _prepare_component_wise_processing(self,
                                           w_tensor,
                                           latent_indicators_tensor,
                                           *input_tensors,
                                           zero_prob_val=0.0):
        """
        Gathers inputs and combines them so that the resulting tensor can be reduced over the
        last axis to compute the (weighted) sums.

        Args:
            w_tensor (Tensor): A ``Tensor`` with the (log-)value of the weights of this node of
                               shape [num_sums, max_sum_size]
            latent_indicators_tensor (Tensor): A ``Tensor`` with the (log-)value of the 'latent' ``IndicatorLeaf``.
            input_tensors (tuple): A tuple of ``Tensor``s  holding the value of the children of this
                                   node.
            zero_prob_val (float): The value of zero probability. This is important to know if some
                                   parts of the computation should be left out for masking.
        Returns:
            A tuple of size 3 containing: a weight ``Tensor`` that can be broadcast across sums, an
            IndicatorLeaf ``Tensor`` that can be applied component-wise to the sums and a ``Tensor`` that
            holds the unweighted values of the sum inputs of shape [batch, num_sums, max_sum_size].
        """

        w_tensor, latent_indicators_tensor, *input_tensors = self._gather_input_tensors(
            w_tensor, latent_indicators_tensor, *input_tensors)

        reducible_inputs = tf.expand_dims(tf.concat(input_tensors,
                                                    axis=self._reduce_axis -
                                                    1),
                                          axis=self._op_axis)

        w_tensor = tf.expand_dims(w_tensor, axis=self._batch_axis)
        if latent_indicators_tensor is not None:
            latent_indicators_tensor = tf.reshape(latent_indicators_tensor,
                                                  shape=(-1, self._num_sums,
                                                         self._max_sum_size))

        return w_tensor, latent_indicators_tensor, reducible_inputs

    @utils.lru_cache
    def _reduce_marginal_inference_log(self, x):
        """Reduces a tensor for marginal log inference by log(sum(exp(x), axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return tf.reduce_logsumexp(x, axis=self._reduce_axis, keepdims=False)

    @utils.lru_cache
    def _reduce_mpe_inference(self, x):
        """Reduces a tensor for MPE non-log inference by max(x, axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return tf.reduce_max(x, axis=self._reduce_axis, keepdims=False)

    @utils.lru_cache
    def _reduce_mpe_inference_log(self, x):
        """Reduces a tensor for MPE log inference by max(x, axis=reduce_axis).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        return self._reduce_mpe_inference(x)

    @utils.lru_cache
    def _reduce_argmax(self, x, num_samples=1):
        """Reduces a tensor by argmax(x, axis=reduce_axis)).

        Args:
            x (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over the
                        last axis.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """
        if conf.argmax_zero:
            # If true, uses TensorFlow's argmax directly, yielding a bias towards the zeroth index
            argmax = tf.argmax(x, axis=self._reduce_axis)
            if num_samples == 1:
                return argmax
            return tf.tile(tf.expand_dims(argmax, axis=-1),
                           [1] * (len(argmax.shape) - 1) +
                           [self._tile_unweighted_size])

        # Return random index in case multiple values equal max
        x_max = tf.expand_dims(self._reduce_mpe_inference(x),
                               self._reduce_axis)
        x_eq_max = tf.cast(tf.equal(x, x_max), tf.float32)
        if self._masked:
            x_eq_max *= tf.expand_dims(tf.cast(self._build_mask(), tf.float32),
                                       axis=self._batch_axis)
        sample = self.multinomial_sample(tf.log(x_eq_max), num_samples)
        return sample

    @staticmethod
    def sample_and_transpose(d, sample_shape):
        sample = d.sample(sample_shape=sample_shape)
        if sample_shape == ():
            return sample
        else:
            return tf.transpose(sample,
                                list(range(1, len(sample.shape))) + [0])

    @utils.lru_cache
    def _reduce_sample_log(self, logits, sample_prob=None, num_samples=1):
        """Samples a tensor with log likelihoods, i.e. sample(x, axis=reduce_axis)).

        Args:
            logits (Tensor): A ``Tensor`` of shape [batch, num_sums, max_sum_size] to reduce over
                the last axis.
            sample_prob (Tensor or float): A ``Tensor`` or float indicating the probability of
                taking a sample.

        Returns:
            A ``Tensor`` reduced over the last axis.
        """

        # Categorical eventually uses non-log probabilities, so here we reuse as much as we can to
        # predetermine it
        def _sample():
            sample = self.multinomial_sample(logits, num_samples=num_samples)
            return sample

        def _select_sample_or_argmax(x):
            mask = tf.less(tf.random_uniform(tf.shape(x)), sample_prob)
            return tf.where(
                mask, x, self._reduce_argmax(logits, num_samples=num_samples))

        if sample_prob is not None:
            if isinstance(sample_prob, (float, int)):
                if sample_prob < 0 or sample_prob > 1:
                    raise ValueError(
                        "{}: Sample probability should be between 0 and 1. Got {} "
                        "instead.".format(self, sample_prob))
                if sample_prob != 0:
                    sample_op = _sample()
                    if sample_prob == 1.0:
                        return sample_op
                    return _select_sample_or_argmax(sample_op)

            return _select_sample_or_argmax(_sample())
        else:
            return _sample()

    @utils.lru_cache
    def multinomial_sample(self, logits, num_samples):
        shape = tf.shape(logits)
        last_dim = shape[-1]
        logits = tf.reshape(logits, (-1, last_dim))
        sample = tf.multinomial(logits, num_samples)

        if self._tile_unweighted_size == num_samples and self._max_sum_size > 1:
            shape = tf.concat((shape[:-1], [num_samples]), axis=0)
            return tf.squeeze(tf.reshape(sample, shape),
                              axis=self._reduce_axis - 1)

        return tf.reshape(sample, shape[:-1])
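
The reduction performed by ``_compute_reducible`` together with ``_reduce_marginal_inference_log`` and ``_reduce_mpe_inference_log`` boils down to a component-wise log-space addition of weights and child values, followed by a logsumexp (marginal) or max (MPE) over the last axis. The following is a minimal NumPy sketch of that reduction; it is not part of the original example and the shapes and values are illustrative only.

import numpy as np
from scipy.special import logsumexp

batch, num_sums, max_sum_size = 2, 3, 4

# Log-weights broadcast over the batch axis: [1, num_sums, max_sum_size]
log_w = np.log(np.full((1, num_sums, max_sum_size), 1.0 / max_sum_size))
# Log-values of the children broadcast over the sums axis: [batch, 1, max_sum_size]
log_vals = np.log(np.random.rand(batch, 1, max_sum_size))

# Component-wise addition in log space, analogous to utils.cwise_add
reducible = log_vals + log_w                   # [batch, num_sums, max_sum_size]

log_marginal = logsumexp(reducible, axis=-1)   # marginal inference, [batch, num_sums]
log_mpe = reducible.max(axis=-1)               # MPE inference, [batch, num_sums]
winners = reducible.argmax(axis=-1)            # indices as used for the MPE path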
Beispiel #20
0
class DenseSPNGenerator:
    """Generates a dense SPN according to the algorithm described in
    Poon & Domingos (UAI 2011).

    Attributes:
        num_decomps (int): Number of decompositions at each level.
        num_subsets (int): Number of variable sub-sets for each decomposition.
        num_mixtures (int): Number of mixtures (sums) for each variable subset.
        input_dist (InputDist): Determines how inputs sharing the same scope
                                (for instance IVs for different values of a
                                random variable) should be included into the
                                generated structure.
        num_input_mixtures (int): Number of mixtures used for combining all
                                  inputs sharing scope when ``input_dist`` is
                                  set to ``MIXTURE``. If set to ``None``,
                                  ``num_mixtures`` is used.
        balanced (bool): Use only balanced decompositions, into subsets of
                         similar cardinality (differing by max 1).
    """

    __logger = get_logger()
    __debug1 = __logger.debug1
    __debug2 = __logger.debug2
    __debug3 = __logger.debug3

    class InputDist(Enum):
        """Determines how inputs sharing the same scope (for instance IVs for
        different values of a random variable) should be included into the
        generated structure."""

        RAW = 0
        """Each input is considered a different distribution over the scope and
        used directly instead of a mixture as an input to product nodes for
        singleton variable subsets."""

        MIXTURE = 1
        """``input_num_mixtures`` mixtures are created over all the inputs
        sharing a scope, effectively creating ``input_num_mixtures``
        distributions over the scope. These mixtures are then used as inputs
        to product nodes for singleton variable subsets."""

    class SubsetInfo:
        """Stores information about a single subset to be decomposed.

        Attributes:
            level(int): Number of the SPN layer where the subset is decomposed.
            subset(list of tuple of tuple): Subset of inputs to decompose
                                            grouped by scope.
            parents(list of Sum): List of sum nodes mixing the outputs of the
                                  generated decompositions. Should be the root
                                  node at the very top.
        """
        def __init__(self, level, subset, parents):
            self.level = level
            self.subset = subset
            self.parents = parents

    def __init__(self,
                 num_decomps,
                 num_subsets,
                 num_mixtures,
                 input_dist=InputDist.MIXTURE,
                 num_input_mixtures=None,
                 balanced=True):
        # Args
        if not isinstance(num_decomps, int) or num_decomps < 1:
            raise ValueError("num_decomps must be a positive integer")
        if not isinstance(num_subsets, int) or num_subsets < 1:
            raise ValueError("num_subsets must be a positive integer")
        if not isinstance(num_mixtures, int) or num_mixtures < 1:
            raise ValueError("num_mixtures must be a positive integer")
        if input_dist not in DenseSPNGenerator.InputDist:
            raise ValueError("Incorrect input_dist: %s", input_dist)
        if (num_input_mixtures is not None
                and (not isinstance(num_input_mixtures, int)
                     or num_input_mixtures < 1)):
            raise ValueError("num_input_mixtures must be None"
                             " or a positive integer")

        # Attributes
        self.num_decomps = num_decomps
        self.num_subsets = num_subsets
        self.num_mixtures = num_mixtures
        self.input_dist = input_dist
        self.balanced = balanced
        if num_input_mixtures is None:
            self.num_input_mixtures = num_mixtures
        else:
            self.num_input_mixtures = num_input_mixtures

        # Stirling numbers and ratios for partition sampling
        self.__stirling = utils.Stirling()

    def generate(self, *inputs, rnd=None, root_name=None):
        """Generate the SPN.

        Args:
            inputs (input_like): Inputs to the generated SPN.
            rnd (Random): Optional. A custom instance of a random number generator
                          ``random.Random`` that will be used instead of the
                          default global instance. This permits using a generator
                          with a custom state independent of the global one.
            root_name (str): Name of the root node of the generated SPN.

        Returns:
           Sum: Root node of the generated SPN.
        """
        self.__debug1(
            "Generating dense SPN (num_decomps=%s, num_subsets=%s,"
            " num_mixtures=%s, input_dist=%s, num_input_mixtures=%s)",
            self.num_decomps, self.num_subsets, self.num_mixtures,
            self.input_dist, self.num_input_mixtures)
        inputs = [Input.as_input(i) for i in inputs]
        input_set = self.__generate_set(inputs)
        self.__debug1("Found %s distinct input scopes", len(input_set))

        # Create root
        root = Sum(name=root_name)

        # Subsets left to process
        subsets = deque()
        subsets.append(
            DenseSPNGenerator.SubsetInfo(level=1,
                                         subset=input_set,
                                         parents=[root]))

        # Process subsets layer by layer
        self.__decomp_id = 1  # Id number of a decomposition, for info only
        while subsets:
            # Process whole layer (all subsets at the same level)
            level = subsets[0].level
            self.__debug1("Processing level %s", level)
            while subsets and subsets[0].level == level:
                subset = subsets.popleft()
                new_subsets = self.__add_decompositions(subset, rnd)
                for s in new_subsets:
                    subsets.append(s)

        return root

    def __generate_set(self, inputs):
        """Generate a set of inputs to the generated SPN grouped by scope.

        Args:
            inputs (list of Input): List of inputs.

        Returns:
           list of tuple of tuple: A list where each element is a tuple of
               all inputs to the generated SPN which share the same scope.
               Each of those scopes is guaranteed to be unique. The tuple
               contains tuples ``(node, index)`` which uniquely identify
               specific inputs.
        """
        scope_dict = {}  # Dict indexed by scope

        def add_input(scope, node, index):
            try:
                # Try appending to existing scope
                scope_dict[scope].add((node, index))
            except KeyError:
                # Scope not in dict, check if it overlaps with other scopes
                for s in scope_dict:
                    if s & scope:
                        raise StructureError(
                            "Differing scopes of inputs overlap")
                # Add to dict
                scope_dict[scope] = set([(node, index)])

        # Process inputs and group by scope
        for inpt in inputs:
            node_scopes = inpt.node.get_scope()
            if inpt.indices is None:
                for index, scope in enumerate(node_scopes):
                    add_input(scope, inpt.node, index)
            else:
                for index in inpt.indices:
                    add_input(node_scopes[index], inpt.node, index)

        # Convert to hashable tuples and sort
        # Sorting might improve performance due to branch prediction
        return [tuple(sorted(i)) for i in scope_dict.values()]

    def __add_decompositions(self, subset_info: SubsetInfo,
                             rnd: random.Random):
        """Add nodes for a single subset, i.e. an instance of ``num_decomps``
        decompositions of ``subset`` into ``num_subsets`` sub-subsets with
        ``num_mixtures`` mixtures per sub-subset.

        Args:
            subset_info(SubsetInfo): Info about the subset being decomposed.
            rnd (Random): A custom instance of a random number generator or
                          ``None`` if default global instance should be used.

        Returns:
            list of SubsetInfo: Info about each new generated subset, which
            requires further decomposition.
        """
        # Get subset partitions
        self.__debug3("Decomposing subset:\n%s", subset_info.subset)
        num_elems = len(subset_info.subset)
        num_subsubsets = min(num_elems,
                             self.num_subsets)  # Requested num subsets
        partitions = utils.random_partitions(subset_info.subset,
                                             num_subsubsets,
                                             self.num_decomps,
                                             balanced=self.balanced,
                                             rnd=rnd,
                                             stirling=self.__stirling)
        self.__debug2(
            "Randomized %s decompositions of a subset"
            " of %s elements into %s sets", len(partitions), num_elems,
            num_subsubsets)

        # Generate nodes for each decomposition/partition
        subsubset_infos = []
        for part in partitions:
            self.__debug2(
                "Decomposition %s: into %s subsubsets of cardinality %s",
                self.__decomp_id, len(part), [len(s) for s in part])
            self.__debug3("Decomposition %s subsubsets:\n%s", self.__decomp_id,
                          part)
            # Handle each subsubset
            sums_id = 1
            prod_inputs = []
            for subsubset in part:
                if len(subsubset) > 1:  # Decomposable further
                    # Add mixtures
                    with tf.name_scope("Sums%s.%s" %
                                       (self.__decomp_id, sums_id)):
                        sums = [
                            Sum(name="Sum%s" % (i + 1))
                            for i in range(self.num_mixtures)
                        ]
                        sums_id += 1
                    # Register the mixtures as inputs of products
                    prod_inputs.append([(s, 0) for s in sums])
                    # Generate subsubset info
                    subsubset_infos.append(
                        DenseSPNGenerator.SubsetInfo(level=subset_info.level +
                                                     1,
                                                     subset=subsubset,
                                                     parents=sums))
                else:  # Non-decomposable
                    if self.input_dist == DenseSPNGenerator.InputDist.RAW:
                        # Register the content of subset as inputs to products
                        prod_inputs.append(next(iter(subsubset)))
                    elif self.input_dist == DenseSPNGenerator.InputDist.MIXTURE:
                        # Add mixtures
                        with tf.name_scope("Sums%s.%s" %
                                           (self.__decomp_id, sums_id)):
                            sums = [
                                Sum(name="Sum%s" % (i + 1))
                                for i in range(self.num_input_mixtures)
                            ]
                            sums_id += 1
                        # Register the mixtures as inputs of products
                        prod_inputs.append([(s, 0) for s in sums])
                        # Connect inputs to mixtures
                        for s in sums:
                            s.add_values(*(list(next(iter(subsubset)))))
            # Add product nodes
            products = self.__add_products(prod_inputs)
            # Connect products to each parent Sum
            for p in subset_info.parents:
                p.add_values(*products)
            # Increment decomposition id
            self.__decomp_id += 1
        return subsubset_infos

    def __add_products(self, prod_inputs):
        """
        Add product nodes for a single decomposition and connect them to their
        input nodes.

        Args:
            prod_inputs (list of list of Node): A list of lists of nodes
                being inputs to the products, grouped by scope.

        Returns:
            list of Product: A list of product nodes.
        """
        selected = [0 for _ in prod_inputs]  # Input selected for each scope
        cont = True
        products = []
        product_num = 1
        with tf.name_scope("Products%s" % self.__decomp_id):
            while cont:
                # Add a product node
                products.append(
                    Product(*[pi[s] for (pi, s) in zip(prod_inputs, selected)],
                            name="Product%s" % product_num))
                product_num += 1
                # Increment selected
                cont = False
                for i, group in enumerate(prod_inputs):
                    if selected[i] < len(group) - 1:
                        selected[i] += 1
                        for j in range(i):
                            selected[j] = 0
                        cont = True
                        break
        return products
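
A hypothetical usage sketch of DenseSPNGenerator follows. It assumes the classes shown in these examples (IndicatorLeaf, Sum, DenseSPNGenerator) are importable from libspn as ``spn``; the import path and parameter values are assumptions, not taken from the original example.

import libspn as spn

# Leaf indicators over 6 binary variables (illustrative sizes)
leaf = spn.IndicatorLeaf(num_vars=6, num_vals=2, name="X")

gen = spn.DenseSPNGenerator(num_decomps=1,
                            num_subsets=3,
                            num_mixtures=2,
                            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE)
root = gen.generate(leaf, root_name="Root")   # generate() returns the root Sum node

# The root is a Sum, so weights can be attached as shown in the sum class above;
# in a full model every sum node in the generated graph needs its own weights.
root.generate_weights()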
Beispiel #21
0
class Model(ABC):
    """An abstract class defining the interface of a model."""

    __logger = get_logger()
    __info = __logger.info

    def __init__(self):
        self._root = None

    def __repr__(self):
        return type(self).__qualname__

    @property
    def root(self):
        """OpNode: Root node of the model."""
        return self._root

    @abstractmethod
    def serialize(self, save_param_vals=True, sess=None):
        """Convert this model into a dictionary for serialization.

        Args:
            save_param_vals (bool): If ``True``, values of parameters will be
                evaluated in a session and stored. The TF variables of parameter
                nodes must already be initialized. If a valid session cannot be
                found, the parameter values will not be retrieved.
            sess (Session): Optional. Session used to retrieve parameter values.
                            If ``None``, the default session is used.

        Returns:
            dict: Dictionary with all the data to be serialized.
        """

    @abstractmethod
    def deserialize(self, data, load_param_vals=True, sess=None):
        """Initialize this model with the ``data`` dict during deserialization.

        Args:
            data (dict): Dictionary with all the data to be deserialized.
            load_param_vals (bool): If ``True``, saved values of parameters will
                                    be loaded and assigned in a session.
            sess (Session): Optional. Session used to assign parameter values.
                            If ``None``, the default session is used.
        """

    @abstractmethod
    def build(self):
        """Build the SPN graph of the model.

        Returns:
           Node: Root node of the generated model.
        """

    def save_to_json(self, path, pretty=False, save_param_vals=True, sess=None):
        """Saves the model to a JSON file.

        Args:
            path (str): Full path to the file.
            pretty (bool): Use pretty printing.
            save_param_vals (bool): If ``True``, values of parameters will be
                evaluated in a session and saved. The TF variables of parameter
                nodes must already be initialized. If a valid session cannot be
                found, the parameter values will not be saved.
            sess (Session): Optional. Session used to retrieve parameter values.
                            If ``None``, the default session is used.
        """
        self.__info("Saving %s to file '%s'" % (self, path))
        data = self.serialize(save_param_vals=save_param_vals, sess=sess)
        data['model_type'] = utils.type2str(type(self))
        utils.json_dump(path, data, pretty=pretty)

    @staticmethod
    def load_from_json(path, load_param_vals=True, sess=None):
        """Loads a model from a JSON file.

        Args:
            path (str): Full path to the file.
            load_param_vals (bool): If ``True``, saved values of parameters will
                                    be loaded and assigned in a session.
            sess (Session): Optional. Session used to assign parameter values.
                            If ``None``, the default session is used.

        Returns:
           Model: The model.
        """
        Model.__info("Loading model from file '%s'" % path)
        data = utils.json_load(path)
        model_type = utils.str2type(data['model_type'])
        model_instance = model_type.__new__(model_type)
        model_instance.deserialize(data, load_param_vals=load_param_vals,
                                   sess=sess)
        Model.__info("Loaded model %s" % model_instance)
        return model_instance
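
A hedged sketch of the save/load round trip defined by ``save_to_json`` and ``load_from_json``. It assumes some concrete ``Model`` subclass, here called ``MyModel`` (a hypothetical name, not part of the original example), that implements ``build``, ``serialize`` and ``deserialize``.

import tensorflow as tf

model = MyModel()          # hypothetical concrete Model subclass
model.build()              # builds the SPN graph and sets the root

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    model.save_to_json("model.json", pretty=True,
                       save_param_vals=True, sess=sess)

with tf.Session() as sess:
    restored = Model.load_from_json("model.json",
                                    load_param_vals=True, sess=sess)
    print(restored.root)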