Ejemplo n.º 1
0
    def test_small_spn(self):
        """Build a small block-node SPN and check its total probability mass.

        The network stacks decomposition, permute-product, sum and
        reduce-product block layers on top of a 13-variable binary
        ``IndicatorLeaf`` and ends in a ``BlockRootSum`` with a latent
        indicator.  All ``2**13`` leaf assignments are fed as one batch with
        the latent indicator set to -1, and the log-sum-exp of the root's log
        values is asserted to be close to 0.0.
        """
        num_vars = 13

        # Layers are created bottom-up; each one consumes the previous layer.
        leaves = spn.IndicatorLeaf(num_vals=2, num_vars=num_vars)
        decomps = BlockRandomDecompositions(leaves, num_decomps=2)
        prod_a = BlockPermuteProduct(decomps, num_factors=4)
        sum_a = BlockSum(prod_a, num_sums_per_block=3)
        prod_b = BlockPermuteProduct(sum_a, num_factors=2)
        sum_b = BlockSum(prod_b, num_sums_per_block=3)
        prod_c = BlockReduceProduct(sum_b, num_factors=2)
        merged = BlockMergeDecompositions(prod_c, num_decomps=1)
        root = BlockRootSum(merged)

        latent = root.generate_latent_indicators(name="Latent")
        spn.generate_weights(
            root, initializer=tf.initializers.random_uniform())

        log_root = spn.LogValue().get_value(root)
        log_total = tf.reduce_logsumexp(log_root)

        # Enumerate every binary assignment: column j of row i is bit j of i.
        num_rows = 2**num_vars
        row_ids = np.arange(num_rows).reshape((num_rows, 1))
        bit_masks = 2**np.arange(num_vars).reshape((1, num_vars))
        all_assignments = np.bitwise_and(row_ids, bit_masks) // bit_masks
        # The latent indicator is fed -1 for every row of the batch.
        latent_feed = -np.ones((all_assignments.shape[0], 1), dtype=np.int32)

        with self.test_session() as sess:
            sess.run(spn.initialize_weights(root))
            total = sess.run(
                log_total, {
                    leaves: all_assignments,
                    latent: latent_feed
                })

        self.assertAllClose(total, 0.0)
Ejemplo n.º 2
0
    def test_compute_value_sum(self, grid_size):
        """Check that an SPN built on two ConvSums layers sums to one.

        Two ``ConvSums`` nodes over the same binary ``IndicatorLeaf`` are
        passed to a ``DenseSPNGenerator``.  Every binary assignment of the
        ``grid_size**2`` variables is fed as one batch and the log-sum-exp of
        the root's marginal log values is asserted to be close to 0.0.

        Args:
            grid_size: Side length of the square variable grid; the leaf has
                ``grid_size**2`` variables.
        """
        indicator_leaf = spn.IndicatorLeaf(num_vals=2, num_vars=grid_size**2)
        convsum = ConvSums(indicator_leaf,
                           spatial_dim_sizes=[grid_size, grid_size],
                           num_channels=4)
        convsum2 = ConvSums(indicator_leaf,
                            spatial_dim_sizes=[grid_size, grid_size],
                            num_channels=4)
        dense_generator = spn.DenseSPNGenerator(
            num_mixtures=4,
            num_subsets=4,
            num_decomps=1,
            input_dist=spn.DenseSPNGenerator.InputDist.MIXTURE)
        root = dense_generator.generate(convsum, convsum2)
        spn.generate_weights(root,
                             initializer=tf.initializers.random_uniform())
        init = spn.initialize_weights(root)

        # Enumerate every binary assignment: column j of row i is bit j of i.
        num_possibilities = 2**(grid_size**2)
        nums = np.arange(num_possibilities).reshape((num_possibilities, 1))
        powers = 2**np.arange(grid_size**2).reshape((1, grid_size**2))
        indicator_feed = np.bitwise_and(nums, powers) // powers

        value_op = spn.LogValue(spn.InferenceType.MARGINAL).get_value(root)
        value_op_sum = tf.reduce_logsumexp(value_op)

        with self.test_session() as sess:
            sess.run(init)
            root_sum = sess.run(value_op_sum,
                                feed_dict={indicator_leaf: indicator_feed})

        self.assertAllClose(root_sum, 0.0)
Ejemplo n.º 3
0
 def generate_random_weights(self, trainable=True):
     """
     Create uniformly random weights for the SPN rooted at ``self._root``.

     The sampling bounds are read from ``self._weight_init_min`` and
     ``self._weight_init_max``.

     Args:
         trainable: Forwarded unchanged to ``spn.generate_weights``.
     """
     lower, upper = self._weight_init_min, self._weight_init_max
     uniform_init = spn.ValueType.RANDOM_UNIFORM(lower, upper)
     spn.generate_weights(
         self._root, init_value=uniform_init, trainable=trainable)
Ejemplo n.º 4
0
    def test_dropconnect(self):
        """Compare a dropconnect ``Sum`` against the mean of a fixed mask.

        The sum's ``_create_dropout_mask`` is replaced with a mock that
        returns the log of a hand-written 2x8 binary mask, so the test is
        deterministic.  With all indicators fed -1, the exponentiated log
        value of the sum is expected to equal the per-row mean of the mask.
        """
        ivs = spn.IVs(num_vals=2, num_vars=4)
        sum_node = spn.Sum(ivs, dropconnect_keep_prob=0.5)
        spn.generate_weights(sum_node)
        init = spn.initialize_weights(sum_node)

        mask_rows = [[0., 1., 0., 1., 1., 1., 0., 1.],
                     [1., 0., 0., 0., 0., 0., 1., 0.]]
        # Replace the random mask with a fixed one (in log space).
        log_mask = tf.expand_dims(tf.log(mask_rows), 1)
        sum_node._create_dropout_mask = MagicMock(return_value=log_mask)

        val_op = tf.exp(sum_node.get_log_value())

        mask_tensor = tf.constant(mask_rows, dtype=tf.float32)
        truth = tf.reduce_mean(mask_tensor, axis=-1, keepdims=True)

        # Every indicator variable is fed -1 for both batch rows.
        ivs_feed = -np.ones((2, 4), dtype=np.int32)
        with self.test_session() as sess:
            sess.run(init)
            dropconnect_out, truth_out = sess.run(
                [val_op, truth], feed_dict={ivs: ivs_feed})

        self.assertAllClose(dropconnect_out, truth_out)
Ejemplo n.º 5
0
# Second sum over indicators 0 and 1 of `indicator_leaves`.
# NOTE: `indicator_leaves` and `sum_11` are defined earlier in the script,
# above this excerpt.
sum_12 = spn.Sum((indicator_leaves, [0, 1]), name="sum_12")

# Connect another two sums to the indicators of the second variable
# (indices 2 and 3 of `indicator_leaves`)
sum_21 = spn.Sum((indicator_leaves, [2, 3]), name="sum_21")
sum_22 = spn.Sum((indicator_leaves, [2, 3]), name="sum_22")

# Connect three product nodes, each pairing one sum over indices [0, 1]
# with one sum over indices [2, 3]
prod_1 = spn.Product(sum_11, sum_21, name="prod_1")
prod_2 = spn.Product(sum_11, sum_22, name="prod_2")
prod_3 = spn.Product(sum_12, sum_22, name="prod_3")

# Connect a root sum over the three products
root = spn.Sum(prod_1, prod_2, prod_3, name="root")

# Generate a latent indicator for the root sum
indicator_y = root.generate_latent_indicators(
    name="indicator_y")  # Could also be added manually

# Generate weights for every node below `root`, initialized uniformly at random
spn.generate_weights(
    root,
    initializer=tf.initializers.random_uniform())  # Could also be added manually

# Inspect: node count, scope and validity of the SPN
print(root.get_num_nodes())
print(root.get_scope())
print(root.is_valid())

# Visualize SPN graph
spn.display_spn_graph(root)
Ejemplo n.º 6
0
    x = spn.LocalSums(x, num_channels=64)
# Create the final layer of products on top of `x`.
# NOTE: `x`, `stack_size` and `num_classes` are defined earlier in the
# script, above this excerpt.
full_scope_prod = spn.ConvProductsDepthwise(x,
                                            padding='wicker_top',
                                            kernel_size=2,
                                            strides=1,
                                            dilation_rate=2**stack_size)
# One sum per class on top of the final products, joined by a single root sum
class_roots = spn.ParallelSums(full_scope_prod, num_sums=num_classes)
root = spn.Sum(class_roots)

# Add an IndicatorLeaf node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root,
                     log=True,
                     initializer=tf.initializers.random_uniform())

# Report basic structure statistics.
# NOTE(review): the product count filters on `spn.ConvProducts`; whether the
# `ConvProductsDepthwise` layer above is included depends on how
# `get_num_nodes` matches node types -- confirm.
print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ConvProducts)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.LocalSums)))

# In[4]:

# Render the TensorFlow graph
spn.display_tf_graph()

#
# Defining the TensorFlow graph
#
# Now that we have defined the SPN graph we can declare the TensorFlow operations needed for training and evaluation. The `MPEState`
# class can be used to find the MPE state of any node in the graph. In
Ejemplo n.º 7
0
    def test_generate_spn(self, num_decomps, num_subsets, num_mixtures,
                          num_input_mixtures, input_dims, input_dist, balanced,
                          node_type, log_weights):
        """Exercise DenseSPNGenerator on a two-level (sub-SPN + upper) model.

        Three sub-SPNs are generated over the same indicator inputs, a group
        of sums is placed on top of each, and a second dense SPN is generated
        over those sums.  For both the first sub-SPN and the complete SPN the
        test checks validity, that the values over all input assignments add
        up to 1.0 (via the linear and the log value ops), and that the MPE
        path counts summed over the batch equal ``batch_size // num_vals``
        for every indicator.

        Args:
            num_decomps: Forwarded to ``DenseSPNGenerator``.
            num_subsets: Forwarded to ``DenseSPNGenerator``.
            num_mixtures: Forwarded to ``DenseSPNGenerator``.
            num_input_mixtures: Forwarded to ``DenseSPNGenerator``.
            input_dims: Indexable whose first two items are the number of
                indicator inputs and the number of variables per input.
            input_dist: A ``DenseSPNGenerator.InputDist`` member.
            balanced: Forwarded to ``DenseSPNGenerator``.
            node_type: A ``DenseSPNGenerator.NodeType`` member; also selects
                which sum node class sits on top of the sub-SPNs.
            log_weights: Passed as ``log`` to ``spn.generate_weights``.
        """
        InputDist = spn.DenseSPNGenerator.InputDist
        NodeType = spn.DenseSPNGenerator.NodeType

        # RAW input with several input mixtures duplicates another case.
        if input_dist == InputDist.RAW and num_input_mixtures != 1:
            return

        # Input parameters
        num_inputs, num_vars = input_dims[0], input_dims[1]
        num_vals = 2

        # Human-readable labels for the enum-valued parameters.
        dist_label = "MIXTURE" if input_dist == InputDist.MIXTURE else "RAW"
        if node_type == NodeType.SINGLE:
            node_label = "SINGLE"
        elif node_type == NodeType.BLOCK:
            node_label = "BLOCK"
        else:
            node_label = "LAYER"

        printc("\n- num_inputs: %s" % num_inputs)
        printc("- num_vars: %s" % num_vars)
        printc("- num_vals: %s" % num_vals)
        printc("- num_decomps: %s" % num_decomps)
        printc("- num_subsets: %s" % num_subsets)
        printc("- num_mixtures: %s" % num_mixtures)
        printc("- input_dist: %s" % dist_label)
        printc("- balanced: %s" % balanced)
        printc("- num_input_mixtures: %s" % num_input_mixtures)
        printc("- node_type: %s" % node_label)
        printc("- log_weights: %s" % log_weights)

        # Indicator inputs shared by all sub-SPNs
        inputs = []
        for idx in range(num_inputs):
            inputs.append(
                spn.IndicatorLeaf(num_vars=num_vars,
                                  num_vals=num_vals,
                                  name=("IndicatorLeaf_%d" % (idx + 1))))

        gen = spn.DenseSPNGenerator(num_decomps=num_decomps,
                                    num_subsets=num_subsets,
                                    num_mixtures=num_mixtures,
                                    input_dist=input_dist,
                                    balanced=balanced,
                                    num_input_mixtures=num_input_mixtures,
                                    node_type=node_type)

        # Three sub-SPNs over the same inputs
        sub_spns = []
        for idx in range(3):
            sub_spns.append(
                gen.generate(*inputs, root_name=("sub_root_%d" % (idx + 1))))
        first_sub = sub_spns[0]

        # Random weights for the first sub-SPN only
        with tf.name_scope("Weights"):
            spn.generate_weights(first_sub,
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)
        sub_spn_init = spn.initialize_weights(first_sub)

        self.assertTrue(first_sub.is_valid())

        # Value ops of the first sub-SPN
        sub_spn_v = first_sub.get_value()
        sub_spn_v_log = first_sub.get_log_value()

        # MPE path ops of the first sub-SPN
        sub_path_gen = spn.MPEPath(log=False)
        sub_path_gen_log = spn.MPEPath(log=True)
        sub_path_gen.get_mpe_path(first_sub)
        sub_path_gen_log.get_mpe_path(first_sub)
        sub_spn_path = [sub_path_gen.counts[inp] for inp in inputs]
        sub_spn_path_log = [sub_path_gen_log.counts[inp] for inp in inputs]

        # Walk the first sub-SPN and gather its parameter nodes
        sub_spn_weight_nodes = []

        def collect_sub_weights(node):
            if node.is_param:
                sub_spn_weight_nodes.append(node)

        spn.traverse_graph(first_sub, fun=collect_sub_weights)

        # Children of each sub-SPN root become inputs of the upper sums
        products_lower = [[v.node for v in sub_spn.values]
                          for sub_spn in sub_spns]

        top_mixture_counts = [2, 1, 3]
        sums_lower = []
        for child_prods, n_top in zip(products_lower, top_mixture_counts):
            if node_type == NodeType.SINGLE:
                group = [spn.Sum(*child_prods) for _ in range(n_top)]
            elif node_type == NodeType.BLOCK:
                group = [spn.ParallelSums(*child_prods, num_sums=n_top)]
            else:
                # A SumsLayer takes the children repeated once per sum.
                group = [
                    spn.SumsLayer(*(child_prods * n_top),
                                  num_or_size_sums=n_top)
                ]
            sums_lower.append(group)

        # Upper SPN over all of the sums created above
        root = gen.generate(*itertools.chain.from_iterable(sums_lower),
                            root_name="root")

        # Random weights for the complete SPN
        with tf.name_scope("Weights"):
            spn.generate_weights(root,
                                 tf.initializers.random_uniform(0.0, 1.0),
                                 log=log_weights)
        spn_init = spn.initialize_weights(root)

        self.assertTrue(root.is_valid())

        # Value ops of the complete SPN
        spn_v = root.get_value()
        spn_v_log = root.get_log_value()

        # MPE path ops of the complete SPN
        full_path_gen = spn.MPEPath(log=False)
        full_path_gen_log = spn.MPEPath(log=True)
        full_path_gen.get_mpe_path(root)
        full_path_gen_log.get_mpe_path(root)
        spn_path = [full_path_gen.counts[inp] for inp in inputs]
        spn_path_log = [full_path_gen_log.counts[inp] for inp in inputs]

        # Walk the complete SPN and gather its parameter nodes
        spn_weight_nodes = []

        def collect_weights(node):
            if node.is_param:
                spn_weight_nodes.append(node)

        spn.traverse_graph(root, fun=collect_weights)

        with self.test_session() as sess:
            sess.run(sub_spn_init)
            sess.run(spn_init)

            # One batch row per joint assignment of all input variables,
            # split column-wise into one chunk per indicator input.
            assignments = itertools.product(range(num_vals),
                                            repeat=(num_inputs * num_vars))
            feed = np.array(list(assignments))
            batch_size = feed.shape[0]
            feed_dict = dict(zip(inputs, np.split(feed, num_inputs, axis=1)))

            # Values and paths of the first sub-SPN
            sub_spn_out = sess.run(sub_spn_v, feed_dict=feed_dict)
            sub_spn_out_log = sess.run(tf.exp(sub_spn_v_log),
                                       feed_dict=feed_dict)
            sub_spn_out_path = sess.run(sub_spn_path, feed_dict=feed_dict)
            sub_spn_out_path_log = sess.run(sub_spn_path_log,
                                            feed_dict=feed_dict)

            # Values and paths of the complete SPN
            spn_out = sess.run(spn_v, feed_dict=feed_dict)
            spn_out_log = sess.run(tf.exp(spn_v_log), feed_dict=feed_dict)
            spn_out_path = sess.run(spn_path, feed_dict=feed_dict)
            spn_out_path_log = sess.run(spn_path_log, feed_dict=feed_dict)

            # The partition function of the sub-SPN and of the complete
            # SPN must be 1.0, in linear and in log space.
            for values in (sub_spn_out, sub_spn_out_log, spn_out,
                           spn_out_log):
                self.assertAlmostEqual(values.sum(), 1.0, places=6)

            # Summed over the batch, every value of every variable must be
            # counted batch_size / num_vals times.
            expected_counts = (
                [batch_size // num_vals] * num_inputs * num_vars * num_vals)
            for path_out in (sub_spn_out_path, sub_spn_out_path_log,
                             spn_out_path, spn_out_path_log):
                self.assertEqual(
                    np.sum(np.hstack(path_out), axis=0).tolist(),
                    expected_counts)
Ejemplo n.º 8
0
    input_dist=input_dist)

# Generate a dense SPN for each class, all over the same leaf indicators.
# NOTE: `dense_generator`, `leaf_indicators`, `num_classes` and
# `inference_type` are defined earlier in the script, above this excerpt.
class_roots = [
    dense_generator.generate(leaf_indicators) for _ in range(num_classes)
]

# Connect the sub-SPNs to a single root sum, then convert the graph to
# layer nodes
root = spn.Sum(*class_roots, name="RootSum")
root = spn.convert_to_layer_nodes(root)

# Add an IVs node to the root as a latent class variable
class_indicators = root.generate_latent_indicators()

# Generate the weights for the SPN rooted at `root`
spn.generate_weights(root)

# Report basic structure statistics of the converted SPN
print("SPN depth: {}".format(root.get_depth()))
print("Number of products layers: {}".format(
    root.get_num_nodes(node_type=spn.ProductsLayer)))
print("Number of sums layers: {}".format(
    root.get_num_nodes(node_type=spn.SumsLayer)))

# Op for initializing all weights
weight_init_op = spn.initialize_weights(root)
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Helper for constructing EM learning ops
em_learning = spn.GDLearning(initial_accum_value=initial_accum_value,
                             root=root,
Ejemplo n.º 9
0
def train(args):
    """Run one complete training experiment configured by ``args``.

    The function loads the data, builds the SPN and its weights, registers
    per-epoch evaluation tasks (classification or likelihood, plus optional
    image-completion metrics), trains with the augmented iterator and stops
    early when progress stalls.  For non-discrete inputs the learned leaf
    locations and scales are written out at the end.

    Args:
        args: Parsed experiment options.  Attributes read here include
            ``name``, ``log_base_path``, ``batch_size``, ``eval_batch_size``,
            the augmentation ranges, ``log_weights``, ``weight_init_min``,
            ``weight_init_max``, ``supervised``, ``novelty_detection``,
            ``completion``, ``completion_batch_size``, ``discrete``,
            ``dataset``, ``normalize_data``, ``num_vals``, ``num_epochs``,
            ``input_dropout`` and ``stop_epsilon``.
    """
    reporter = ExperimentLogger(args.name, args.log_base_path)
    reporter.write_hyperparameter_dict(vars(args))
    test_x, test_y, train_x, train_y, num_classes = load_data(args)

    num_rows, num_cols = train_x.shape[1:3]
    num_vars = train_x.shape[1] * train_x.shape[2]
    num_dims = train_x.shape[-1]

    # Flatten the two spatial axes into one variable axis.
    train_x = np.squeeze(train_x.reshape(-1, num_vars, num_dims))
    test_x = np.squeeze(test_x.reshape(-1, num_vars, num_dims))

    # Augmented iterator is used for the update steps; the plain iterators
    # are used for the per-epoch evaluation tasks.
    train_augmented_iterator = ImageIterator(
        [train_x, train_y],
        batch_size=args.batch_size,
        shuffle=True,
        width_shift_range=args.width_shift_range,
        height_shift_range=args.height_shift_range,
        shear_range=args.shear_range,
        rotation_range=args.rotation_range,
        zoom_range=args.zoom_range,
        horizontal_flip=args.horizontal_flip,
        image_dims=(num_rows, num_cols))
    train_iterator = DataIterator([train_x, train_y],
                                  batch_size=args.batch_size,
                                  shuffle=True)
    test_iterator = DataIterator([test_x, test_y],
                                 batch_size=args.eval_batch_size,
                                 shuffle=False)

    in_var, root = build_spn(args, num_dims, num_vars, train_x, train_y)
    spn.generate_weights(root,
                         log=args.log_weights,
                         initializer=tf.initializers.random_uniform(
                             args.weight_init_min, args.weight_init_max))

    init_weights = spn.initialize_weights(root)

    correct, labels_node, loss, likelihood, update_op, pred_op, reg_loss, loss_per_sample, \
    mpe_in_var = setup_learning(args, in_var, root)

    # Evaluation callback for supervised runs.  `sess` is bound later, when
    # the session below is opened; the callback is only invoked after that.
    def evaluate_classification(image_batch, labels_batch, epoch, step):
        feed_dict = {in_var: image_batch}
        if args.supervised:
            feed_dict[labels_node] = labels_batch
        loss_out, correct_out, likelihood_out, reg_loss_out, loss_per_sample_out = sess.run(
            [loss, correct, likelihood, reg_loss, loss_per_sample],
            feed_dict=feed_dict)
        return [loss_out, reg_loss_out, correct_out * 100, likelihood_out]

    # Evaluation callback for unsupervised runs: likelihood only.
    def evaluate_likelihood(image_batch, labels_batch, epoch, step):
        feed_dict = {in_var: image_batch}
        if args.supervised:
            feed_dict[labels_node] = labels_batch
        return sess.run(likelihood, feed_dict=feed_dict)

    # These are default evaluation metrics to be measured at the end of each epoch
    metrics = ["loss", "reg_loss", "accuracy", 'likelihood']
    gm_default = GroupedMetrics(
        reporter=reporter,
        reduce_fun=np.mean if not args.novelty_detection else 'roc')

    if args.supervised:

        gm_default.add_task('test_epoch.csv',
                            fun=evaluate_classification,
                            iterator=test_iterator,
                            metric_names=metrics,
                            desc="Evaluate test ",
                            batch_size=args.eval_batch_size)
        gm_default.add_task('train_epoch.csv',
                            fun=evaluate_classification,
                            iterator=train_iterator,
                            metric_names=metrics,
                            desc="Evaluate train",
                            batch_size=args.eval_batch_size,
                            return_val=True)
    else:
        gm_default.add_task('test_epoch.csv',
                            fun=evaluate_likelihood,
                            iterator=test_iterator,
                            metric_names=['likelihood'],
                            desc="Likelihood test ",
                            batch_size=args.eval_batch_size)
        gm_default.add_task('train_epoch.csv',
                            fun=evaluate_likelihood,
                            iterator=train_iterator,
                            metric_names=['likelihood'],
                            desc="Likelihood train",
                            batch_size=args.eval_batch_size,
                            return_val=True)
    if args.completion:

        with tf.name_scope("CompletionSummary"):
            # For discrete data the ground truth is fed through a separate
            # placeholder, because the input feed has the pixels to be
            # completed overwritten with -1.
            truth = in_var.feed if not args.discrete else tf.placeholder(
                tf.float32, [None, num_vars])
            completion_indices = tf.equal(
                in_var.feed, -1) if args.discrete else tf.logical_not(
                    in_var.evidence)
            shape = (-1, num_rows, num_cols, num_dims)
            mosaic = impainting_mosaic(
                reconstruction=tf.reshape(mpe_in_var, shape),
                truth=tf.reshape(truth, shape),
                completion_indices=tf.reshape(completion_indices, shape),
                num_rows=4,
                batch_size=args.completion_batch_size,
                invert=args.dataset == "mnist")
            mosaic_summary = tf.summary.image("Completion", mosaic)

        def completion_left(image_batch, labels_batch, epoch, step):
            """Measure completion quality with the left half hidden."""
            shape = [len(image_batch), num_rows, num_cols // 2]
            # NOTE(review): presumably meant to add a channel axis for
            # multi-channel images -- confirm the direction of this test.
            if np.prod(shape) > image_batch.size:
                shape.append(3)
            # Ones mark pixels to complete: left half along the column axis.
            # Builtin `bool` replaces `np.bool`, which NumPy 1.24 removed.
            completion_ind = np.concatenate([np.ones(shape), np.zeros(shape)], axis=2) \
                .astype(bool)
            evidence_ind = np.logical_not(completion_ind)
            evidence_ind = np.reshape(evidence_ind, image_batch.shape[:2])
            completion_ind = np.reshape(completion_ind, image_batch.shape[:2])
            return _measure_completion(completion_ind,
                                       epoch,
                                       evidence_ind,
                                       image_batch,
                                       labels_batch,
                                       step,
                                       tag='left',
                                       writer=test_writer_left)

        def completion_bottom(image_batch, labels_batch, epoch, step):
            """Measure completion quality with the bottom half hidden."""
            shape = [len(image_batch), num_rows // 2, num_cols]
            # NOTE(review): presumably meant to add a channel axis for
            # multi-channel images -- confirm the direction of this test.
            if np.prod(shape) > image_batch.size:
                shape.append(3)
            # Ones mark pixels to complete: bottom half along the row axis.
            # Builtin `bool` replaces `np.bool`, which NumPy 1.24 removed.
            completion_ind = np.concatenate([np.zeros(shape), np.ones(shape)], axis=1) \
                .astype(bool)
            evidence_ind = np.logical_not(completion_ind)
            evidence_ind = np.reshape(evidence_ind, image_batch.shape[:2])
            completion_ind = np.reshape(completion_ind, image_batch.shape[:2])

            return _measure_completion(completion_ind,
                                       epoch,
                                       evidence_ind,
                                       image_batch,
                                       labels_batch,
                                       step,
                                       tag='bottom',
                                       writer=test_writer_bottom)

        def _measure_completion(completion_ind,
                                epoch,
                                evidence_ind,
                                image_batch,
                                labels_batch,
                                step,
                                writer,
                                tag="comp"):
            """Run MPE completion on a batch and return per-sample metrics.

            Returns the tuple ``(l1, l2, hamming, psnr, tv)`` computed on the
            pixels selected by ``completion_ind``.
            """
            if args.discrete:
                # Discrete inputs mark missing pixels with -1 in the feed.
                im = image_batch.copy()
                im[completion_ind] = -1
                feed_dict = {in_var: im}
            else:
                # Continuous inputs use a separate boolean evidence mask.
                feed_dict = {
                    in_var: image_batch,
                    in_var.evidence: evidence_ind
                }
            if args.supervised:
                feed_dict[labels_node] = labels_batch
            if step == 0:
                # Only the first batch of an epoch is written as a mosaic.
                if args.discrete:
                    feed_dict[truth] = image_batch
                mpe_in_var_out, mosaic_summary_out, mosaic_out = sess.run(
                    [mpe_in_var, mosaic_summary, mosaic], feed_dict=feed_dict)
                writer.add_summary(mosaic_summary_out, epoch)
                reporter.write_image(
                    np.squeeze(mosaic_out, axis=0),
                    'completion/epoch_{}_{}.png'.format(epoch, tag))
            else:
                mpe_in_var_out = sess.run(mpe_in_var, feed_dict=feed_dict)

            if not args.normalize_data:
                mpe_in_var_out *= 255
                orig = image_batch.copy() * 255
            else:
                orig = image_batch

            # `hamming` is the fraction of completed pixels reproduced exactly.
            hamming = np.equal(orig, mpe_in_var_out)[completion_ind]
            max_fluctuation = args.num_vals**2 if args.discrete else 1.0
            l2 = np.square(orig - mpe_in_var_out)[completion_ind]
            l1 = np.abs(orig - mpe_in_var_out)[completion_ind]
            # Boolean indexing flattens, so restore one row per sample
            # before averaging.
            hamming = np.mean(hamming.reshape(len(orig), -1), axis=-1)
            l2 = np.mean(l2.reshape(len(orig), -1), axis=-1)
            l1 = np.mean(l1.reshape(len(orig), -1), axis=-1)
            psnr = 10 * np.log10(max_fluctuation / l2)
            tv = tv_norm(
                mpe_in_var_out.reshape(-1, num_rows, num_cols, num_dims),
                completion_ind.reshape(-1, num_rows, num_cols, num_dims))
            return l1, l2, hamming, psnr, tv

        gm_default.add_task(
            "test_epoch.csv",
            fun=completion_bottom,
            iterator=test_iterator,
            metric_names=['l1_b', 'l2_b', 'hamming_b', 'psnr_b', 'tv_b'],
            desc='Completion bottom',
            batch_size=args.completion_batch_size)
        gm_default.add_task(
            "test_epoch.csv",
            fun=completion_left,
            iterator=test_iterator,
            metric_names=['l1_l', 'l2_l', 'hamming_l', 'psnr_l', 'tv_l'],
            desc='Completion left  ',
            batch_size=args.completion_batch_size)

    # Reporting total number of trainable variables
    trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    total_trainable_parameters = sum(
        [np.prod(v.shape.as_list()) for v in trainable_vars])
    reporter.write_line('num_trainable_parameters.csv',
                        total_trainable_parameters=total_trainable_parameters)
    logger.info(
        "Num trainable parameters = {}".format(total_trainable_parameters))

    # Remember five last metrics for determining stop criterion
    progress_history = deque(maxlen=5)
    progress_metric = 'loss' if args.supervised else 'likelihood'
    with tf.Session() as sess:
        train_writer, test_writer = initialize_graph(init_weights, reporter,
                                                     sess)
        test_writer_left = reporter.tfwriter('test',
                                             'completion',
                                             'left',
                                             exist_ok=True)
        test_writer_bottom = reporter.tfwriter('test',
                                               'completion',
                                               'bottom',
                                               exist_ok=True)

        # Baseline evaluation before any update step.
        progress_prev = gm_default.evaluate_one_epoch(0)[progress_metric]
        progress_history.append(progress_prev)
        for epoch in range(args.num_epochs):
            # Train, nothing more nothing less
            for image_batch, labels_batch in train_augmented_iterator.iter_epoch(
                    "Train"):
                if args.input_dropout:
                    # Drop each input variable with probability
                    # args.input_dropout.
                    dropout_mask = np.less(
                        np.random.rand(*image_batch.shape[:2]),
                        args.input_dropout)
                    if args.discrete:
                        image_batch_copy = image_batch.copy()
                        image_batch_copy[dropout_mask] = -1
                        feed_dict = {in_var: image_batch_copy}
                    else:
                        feed_dict = {
                            in_var: image_batch,
                            in_var.evidence: np.logical_not(dropout_mask)
                        }
                else:
                    feed_dict = {in_var: image_batch}

                if args.supervised:
                    feed_dict[labels_node] = labels_batch
                sess.run(update_op, feed_dict=feed_dict)

            # Check stopping criterion: the last five values barely vary, or
            # the metric became NaN or exactly zero.
            progress_epoch = gm_default.evaluate_one_epoch(epoch +
                                                           1)[progress_metric]
            progress_history.append(progress_epoch)
            if (len(progress_history) == 5 and np.std(progress_history) < args.stop_epsilon) or \
                np.isnan(progress_epoch) or progress_epoch == 0.0:
                print("Stopping criterion reached!")
                break

        # Store locations and scales
        if not args.discrete:
            loc, scale = sess.run([in_var.loc_variable, in_var.scale_variable])
            reporter.write_numpy(loc, "dist/loc")
            reporter.write_numpy(scale, "dist/scale")
            print("Locations\n", np.unique(loc))
            print("\nScales:\n", np.unique(scale))
Ejemplo n.º 10
0
    def test_compare_manual_conv(self, log_weights, inference_type):
        """Compare ``ConvSums`` with a hand-built grid of ``ParallelSums``.

        One ``ParallelSums`` node is created per cell of a 4x4 grid, each
        over that cell's indicators, and all of them share a single
        ``Weights`` node with a ``ConvSums`` node over the whole leaf.  The
        same dense SPN structure is generated on top of both variants by
        restoring the random generator state between the two ``generate``
        calls.  Root values, MPE path counts and weight counts of the two
        SPNs are then asserted to match.

        Args:
            log_weights: Passed as ``log`` when creating weights.
            inference_type: Passed as ``value_inference_type`` to
                ``spn.MPEPath``.
        """
        spn.conf.argmax_zero = True
        spatial_dims = [4, 4]
        nrows, ncols = spatial_dims
        num_vals = 4
        batch_size = 128
        num_vars = spatial_dims[0] * spatial_dims[1]
        indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars,
                                           num_vals=num_vals)
        num_sums = 32
        weights = spn.Weights(num_weights=num_vals,
                              num_sums=num_sums,
                              initializer=tf.initializers.random_uniform(),
                              log=log_weights)

        # One ParallelSums per grid cell, over that cell's num_vals
        # consecutive indicators (cells are laid out row-major).
        parsums = []
        for row in range(nrows):
            for col in range(ncols):
                indices = list(
                    range(row * (ncols * num_vals) + col * num_vals,
                          row * (ncols * num_vals) + (col + 1) * num_vals))
                parsums.append(
                    spn.ParallelSums((indicator_leaf, indices),
                                     num_sums=num_sums,
                                     weights=weights))

        convsum = spn.ConvSums(indicator_leaf,
                               num_channels=num_sums,
                               weights=weights,
                               spatial_dim_sizes=spatial_dims)

        dense_gen = spn.DenseSPNGenerator(
            num_decomps=1,
            num_mixtures=2,
            num_subsets=2,
            input_dist=spn.DenseSPNGenerator.InputDist.RAW,
            node_type=spn.DenseSPNGenerator.NodeType.BLOCK)

        # Restore the RNG state between the two generate calls so both see
        # the same random sequence.
        rnd = random.Random(1234)
        rnd_state = rnd.getstate()
        conv_root = dense_gen.generate(convsum, rnd=rnd)
        rnd.setstate(rnd_state)

        parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
        parsum_root = dense_gen.generate(parsum_concat, rnd=rnd)

        self.assertTrue(conv_root.is_valid())
        self.assertTrue(parsum_root.is_valid())

        self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

        spn.generate_weights(conv_root, log=log_weights)
        spn.generate_weights(parsum_root, log=log_weights)

        # Re-attach the shared weights node after weight generation.
        convsum.set_weights(weights)
        for parsum in parsums:
            parsum.set_weights(weights)

        init_conv = spn.initialize_weights(conv_root)
        init_parsum = spn.initialize_weights(parsum_root)

        path_conv = spn.MPEPath(value_inference_type=inference_type)
        path_conv.get_mpe_path(conv_root)

        path_parsum = spn.MPEPath(value_inference_type=inference_type)
        path_parsum.get_mpe_path(parsum_root)

        indicator_leaf_count_parsum = path_parsum.counts[indicator_leaf]
        indicator_leaf_count_convsum = path_conv.counts[indicator_leaf]

        weight_counts_parsum = path_parsum.counts[weights]
        weight_counts_conv = path_conv.counts[weights]

        root_val_parsum = path_parsum.value.values[parsum_root]
        root_val_conv = path_conv.value.values[conv_root]

        parsum_counts = path_parsum.counts[parsum_concat]
        conv_counts = path_conv.counts[convsum]

        # Random feed with values drawn from {0, 1}.
        indicator_feed = np.random.randint(2, size=batch_size * num_vars)\
            .reshape((batch_size, num_vars))
        with tf.Session() as sess:
            sess.run([init_conv, init_parsum])
            indicator_counts_conv_out, indicator_count_parsum_out = sess.run(
                [indicator_leaf_count_convsum, indicator_leaf_count_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            root_conv_value_out, root_parsum_value_out = sess.run(
                [root_val_conv, root_val_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            weight_counts_conv_out, weight_counts_parsum_out = sess.run(
                [weight_counts_conv, weight_counts_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            weight_value_conv_out, weight_value_parsum_out = sess.run([
                convsum.weights.node.variable, parsums[0].weights.node.variable
            ])

            parsum_counts_out, conv_counts_out = sess.run(
                [parsum_counts, conv_counts],
                feed_dict={indicator_leaf: indicator_feed})

            parsum_concat_val, convsum_val = sess.run(
                [
                    path_parsum.value.values[parsum_concat],
                    path_conv.value.values[convsum]
                ],
                feed_dict={indicator_leaf: indicator_feed})

        # Values of both sum layers must be non-positive.
        self.assertTrue(np.all(np.less_equal(convsum_val, 0.0)))
        self.assertTrue(np.all(np.less_equal(parsum_concat_val, 0.0)))
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllClose(indicator_counts_conv_out,
                            indicator_count_parsum_out)
        self.assertAllClose(parsum_counts_out, conv_counts_out)
        self.assertAllClose(weight_counts_conv_out, weight_counts_parsum_out)
Ejemplo n.º 11
0
    def test_compare_manual_conv(self, log_weights, inference_type):
        """Compare a ``LocalSums`` SPN against a hand-built equivalent.

        Two SPNs with the same product/sum structure are built on top of a
        single indicator leaf laid out as a 2x2 grid:

        * the "conv" SPN, whose first layer is one ``spn.LocalSums`` node;
        * the "parsum" SPN, whose first layer is a ``spn.Concat`` of one
          ``spn.ParallelSums`` node per grid cell.

        The weights of the conv SPN are copied into the parsum SPN, after
        which the first-layer values, root log-values, weight values and the
        MPE-path counts of both SPNs are asserted to agree.

        Args:
            log_weights: Forwarded as ``log=`` to ``spn.Weights`` and
                ``spn.generate_weights``.
            inference_type: Forwarded as ``value_inference_type=`` to
                ``spn.MPEPath``.
        """
        # ``spn.conf.argmax_zero`` is process-wide state. Remember the value
        # that was active before this test and restore it afterwards so the
        # setting does not leak into tests that run later.
        previous_argmax_zero = spn.conf.argmax_zero
        self.addCleanup(setattr, spn.conf, "argmax_zero",
                        previous_argmax_zero)
        spn.conf.argmax_zero = True

        grid_dims = [2, 2]
        nrows, ncols = grid_dims
        num_vals = 4
        batch_size = 256
        num_vars = grid_dims[0] * grid_dims[1]
        indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars,
                                           num_vals=num_vals)
        num_sums = 32
        # One weight node holding the weights of all cells; it is attached to
        # the LocalSums node below.
        weights = spn.Weights(num_weights=num_vals,
                              num_sums=num_sums * num_vars,
                              initializer=tf.initializers.random_uniform(),
                              log=log_weights)

        # Split the shared weight variable into ``num_vars`` equal pieces
        # (tf.split's default axis), one piece per ParallelSums node.
        weights_per_cell = tf.split(weights.variable,
                                    num_or_size_splits=num_vars)

        # One ParallelSums node per grid cell, in row-major order, each
        # connected to the ``num_vals`` consecutive indicators of its cell.
        parsums = []
        for row in range(nrows):
            for col in range(ncols):
                first_index = (row * ncols + col) * num_vals
                indices = list(range(first_index, first_index + num_vals))
                parsums.append(
                    spn.ParallelSums((indicator_leaf, indices),
                                     num_sums=num_sums))

        parsum_concat = spn.Concat(*parsums, name="ParSumConcat")
        convsum = spn.LocalSums(indicator_leaf,
                                num_channels=num_sums,
                                weights=weights,
                                spatial_dim_sizes=grid_dims)

        # Upper layers of the conv SPN.
        prod00_conv = spn.PermuteProducts(
            (convsum, list(range(num_sums))),
            (convsum, list(range(num_sums, num_sums * 2))),
            name="Prod00")
        prod01_conv = spn.PermuteProducts(
            (convsum, list(range(num_sums * 2, num_sums * 3))),
            (convsum, list(range(num_sums * 3, num_sums * 4))),
            name="Prod01")
        sum00_conv = spn.ParallelSums(prod00_conv, num_sums=2)
        sum01_conv = spn.ParallelSums(prod01_conv, num_sums=2)

        prod10_conv = spn.PermuteProducts(sum00_conv,
                                          sum01_conv,
                                          name="Prod10")

        conv_root = spn.Sum(prod10_conv)

        # Upper layers of the parsum SPN, mirroring the conv SPN.
        prod00_pars = spn.PermuteProducts(
            (parsum_concat, list(range(num_sums))),
            (parsum_concat, list(range(num_sums, num_sums * 2))))
        prod01_pars = spn.PermuteProducts(
            (parsum_concat, list(range(num_sums * 2, num_sums * 3))),
            (parsum_concat, list(range(num_sums * 3, num_sums * 4))))

        sum00_pars = spn.ParallelSums(prod00_pars, num_sums=2)
        sum01_pars = spn.ParallelSums(prod01_pars, num_sums=2)

        prod10_pars = spn.PermuteProducts(sum00_pars, sum01_pars)

        parsum_root = spn.Sum(prod10_pars)

        # (conv node, parsum node) pairs whose weights are copied conv -> parsum.
        node_pairs = [(sum00_conv, sum00_pars), (sum01_conv, sum01_pars),
                      (conv_root, parsum_root)]

        self.assertTrue(conv_root.is_valid())
        self.assertTrue(parsum_root.is_valid())

        self.assertAllEqual(parsum_concat.get_scope(), convsum.get_scope())

        spn.generate_weights(conv_root,
                             log=log_weights,
                             initializer=tf.initializers.random_uniform())
        spn.generate_weights(parsum_root,
                             log=log_weights,
                             initializer=tf.initializers.random_uniform())

        # Re-attach the shared weight node after weight generation.
        convsum.set_weights(weights)

        # Ops that overwrite the parsum SPN's weights with the conv SPN's.
        copy_weight_ops = []
        parsum_weight_nodes = []
        for p, w in zip(parsums, weights_per_cell):
            copy_weight_ops.append(tf.assign(p.weights.node.variable, w))
            parsum_weight_nodes.append(p.weights.node)

        for wc, wp in node_pairs:
            copy_weight_ops.append(
                tf.assign(wp.weights.node.variable, wc.weights.node.variable))

        copy_weights_op = tf.group(*copy_weight_ops)

        init_conv = spn.initialize_weights(conv_root)
        init_parsum = spn.initialize_weights(parsum_root)

        path_conv = spn.MPEPath(value_inference_type=inference_type)
        path_conv.get_mpe_path(conv_root)

        path_parsum = spn.MPEPath(value_inference_type=inference_type)
        path_parsum.get_mpe_path(parsum_root)

        indicator_counts_parsum = path_parsum.counts[indicator_leaf]
        indicator_counts_convsum = path_conv.counts[indicator_leaf]

        # The parsum SPN has one weight node per cell; concatenate their
        # counts and values so they can be compared against the single weight
        # node of the conv SPN.
        weight_counts_parsum = tf.concat(
            [path_parsum.counts[w] for w in parsum_weight_nodes], axis=1)
        weight_counts_conv = path_conv.counts[weights]

        weight_parsum_concat = tf.concat(
            [w.variable for w in parsum_weight_nodes], axis=0)

        root_val_parsum = parsum_root.get_log_value()
        root_val_conv = conv_root.get_log_value()

        parsum_counts = path_parsum.counts[parsum_concat]
        conv_counts = path_conv.counts[convsum]

        # Random feed with entries drawn from {-1, 0, 1}.
        indicator_feed = np.random.randint(-1, 2, size=batch_size * num_vars)\
            .reshape((batch_size, num_vars))
        with tf.Session() as sess:
            sess.run([init_conv, init_parsum])
            # Must run after initialization so both SPNs share weight values.
            sess.run(copy_weights_op)
            indicator_counts_conv_out, indicator_counts_parsum_out = sess.run(
                [indicator_counts_convsum, indicator_counts_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            root_conv_value_out, root_parsum_value_out = sess.run(
                [root_val_conv, root_val_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            weight_counts_conv_out, weight_counts_parsum_out = sess.run(
                [weight_counts_conv, weight_counts_parsum],
                feed_dict={indicator_leaf: indicator_feed})

            weight_value_conv_out, weight_value_parsum_out = sess.run(
                [convsum.weights.node.variable, weight_parsum_concat])

            parsum_counts_out, conv_counts_out = sess.run(
                [parsum_counts, conv_counts],
                feed_dict={indicator_leaf: indicator_feed})

            parsum_concat_val, convsum_val = sess.run(
                [
                    path_parsum.value.values[parsum_concat],
                    path_conv.value.values[convsum]
                ],
                feed_dict={indicator_leaf: indicator_feed})

        self.assertAllClose(convsum_val, parsum_concat_val)
        self.assertAllClose(weight_value_conv_out, weight_value_parsum_out)
        self.assertAllClose(root_conv_value_out, root_parsum_value_out)
        self.assertAllEqual(indicator_counts_conv_out,
                            indicator_counts_parsum_out)
        self.assertAllEqual(parsum_counts_out, conv_counts_out)
        self.assertAllEqual(weight_counts_conv_out, weight_counts_parsum_out)
Ejemplo n.º 12
0
    def test_compute_dense_gen_two_spatial_decomps_v2(self, node_type,
                                                      input_dist):
        """Build a two-level, two-decomposition conv SPN and check its value.

        Two stacks of ``ConvProducts`` + ``ConvSums`` (one using dilation, one
        using strides) are built on a 32x32 indicator grid, combined in a
        second level of the same shape, and topped with a dense SPN from
        ``spn.DenseSPNGenerator``. The root must be valid and its log-value
        must be 0.0 when every indicator is fed -1.

        Args:
            node_type: ``spn.DenseSPNGenerator.NodeType`` for the dense part.
            input_dist: ``spn.DenseSPNGenerator.InputDist`` for the dense part.
        """
        num_input_channels = 2
        grid_dims = [32, 32]
        num_vars = grid_dims[0] * grid_dims[1]
        indicator_leaf = spn.IndicatorLeaf(num_vars=num_vars,
                                           num_vals=num_input_channels)

        # For RAW input with BLOCK/LAYER nodes, generate with SINGLE nodes
        # first and convert to layer nodes afterwards.
        generator_cls = spn.DenseSPNGenerator
        convert_after = (
            input_dist == generator_cls.InputDist.RAW
            and node_type in [generator_cls.NodeType.BLOCK,
                              generator_cls.NodeType.LAYER])
        if convert_after:
            node_type = generator_cls.NodeType.SINGLE

        def valid_conv_products(inputs, dims, channels, dilation, stride):
            # All ConvProducts in this test use a 2x2 kernel, 'valid' padding.
            return spn.ConvProducts(*inputs,
                                    spatial_dim_sizes=dims,
                                    num_channels=channels,
                                    padding='valid',
                                    dilation_rate=dilation,
                                    strides=stride,
                                    kernel_size=2)

        def two_channel_conv_sums(child, dims):
            return spn.ConvSums(child, num_channels=2, spatial_dim_sizes=dims)

        # Level 1, first decomposition (dilated).
        dilated_a = valid_conv_products(
            [indicator_leaf], grid_dims, channels=16, dilation=2, stride=1)
        dilated_b = valid_conv_products(
            [dilated_a], [30, 30], channels=512, dilation=1, stride=4)
        sums_dilated = two_channel_conv_sums(dilated_b, [8, 8])

        # Level 1, second decomposition (strided).
        strided_a = valid_conv_products(
            [indicator_leaf], grid_dims, channels=16, dilation=1, stride=2)
        strided_b = valid_conv_products(
            [strided_a], [16, 16], channels=512, dilation=1, stride=2)
        sums_strided = two_channel_conv_sums(strided_b, [8, 8])

        level1_outputs = [sums_strided, sums_dilated]

        # NOTE(review): both level-2 ConvSums below are declared as 4x4. By
        # the same valid-convolution arithmetic that yields the level-1 sizes
        # (30, 8, 16), their inputs would appear to be 2x2 -- confirm whether
        # [4, 4] is intended.

        # Level 2, first decomposition (dilated).
        dilated_a_l2 = valid_conv_products(
            level1_outputs, [8, 8], channels=512, dilation=2, stride=1)
        dilated_b_l2 = valid_conv_products(
            [dilated_a_l2], [6, 6], channels=512, dilation=1, stride=4)
        sums_dilated_l2 = two_channel_conv_sums(dilated_b_l2, [4, 4])

        # Level 2, second decomposition (strided).
        strided_a_l2 = valid_conv_products(
            level1_outputs, [8, 8], channels=512, dilation=1, stride=2)
        strided_b_l2 = valid_conv_products(
            [strided_a_l2], [4, 4], channels=512, dilation=1, stride=2)
        sums_strided_l2 = two_channel_conv_sums(strided_b_l2, [4, 4])

        # Dense SPN on top of both level-2 decompositions.
        generator = generator_cls(num_mixtures=2,
                                  num_decomps=1,
                                  num_subsets=2,
                                  node_type=node_type,
                                  input_dist=input_dist)
        root = generator.generate(sums_strided_l2, sums_dilated_l2)
        if convert_after:
            root = generator.convert_to_layer_nodes(root)

        # Assert valid
        self.assertTrue(root.is_valid())

        # Set up weights, their initializer and the root log-value op.
        spn.generate_weights(root)
        init_op = spn.initialize_weights(root)
        log_value = tf.squeeze(root.get_log_value())

        # A single sample in which every indicator variable is fed -1.
        all_minus_one = np.full((1, num_vars), -1, dtype=np.int32)
        with self.test_session() as sess:
            sess.run(init_op)
            observed = sess.run(log_value, {indicator_leaf: all_minus_one})

        self.assertAllClose(observed, 0.0)