def init_ops(self, no_mpe=False):
    """Init learning ops & MPE state."""
    self._init_ops_basics()
    if not no_mpe:
        print("Initializing MPE Ops...")
        mpe_state_gen = spn.MPEState(
            log=True, value_inference_type=spn.InferenceType.MPE)
        if self._template_mode == NodeTemplate.code():  # NodeTemplate
            if not self._expanded:
                self._mpe_state = mpe_state_gen.get_state(
                    self._root, self._catg_inputs)
            else:
                self._mpe_state = mpe_state_gen.get_state(
                    self._root, self._semantic_inputs)
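# Hedged usage sketch for init_ops() above. `TemplateSpn` is a stand-in name
# for the class that owns the method, and the constructor, session wiring and
# `catg_batch` feed are assumptions, not part of the original code.
import tensorflow as tf
import libspn as spn

model = TemplateSpn(...)              # hypothetical constructor
model.init_ops(no_mpe=False)          # also builds the MPE state ops
with tf.Session() as sess:
    sess.run(spn.initialize_weights(model._root))
    # Evaluate the MPE state op to recover the most likely category inputs
    mpe_catg, = sess.run(model._mpe_state,
                         feed_dict={model._catg_inputs: catg_batch})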
# Op for initializing all weights
weight_init_op = spn.initialize_weights(root)
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Set up ops for discriminative GD learning
gd_learning = spn.GDLearning(
    root=root, learning_task_type=spn.LearningTaskType.SUPERVISED,
    learning_method=spn.LearningMethodType.DISCRIMINATIVE)
optimizer = AMSGrad(learning_rate=learning_rate)
# Use post_gradient_ops=True to also normalize weights (and clip Gaussian variance)
gd_update_op = gd_learning.learn(optimizer=optimizer, post_gradient_ops=True)

# Compute predictions and matches
mpe_state = spn.MPEState()
root_marginalized = spn.Sum(root.values[0], weights=root.weights)
marginalized_ivs = root_marginalized.generate_latent_indicators(
    feed=-tf.ones_like(class_indicators.feed))
predictions, = mpe_state.get_state(root_marginalized, marginalized_ivs)
with tf.name_scope("MatchPredictionsAndTarget"):
    match_op = tf.equal(tf.to_int64(predictions),
                        tf.to_int64(class_indicators.feed))

# Training the SPN
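# Hedged training-loop sketch around the ops above, in plain TF1 style. The
# leaf input node is not named in this snippet, so `leaf_indicators` is
# borrowed from the EM example below; `train_x`, `train_y`, `batch_size` and
# `num_epochs` are assumed to come from the surrounding notebook.
with tf.Session() as sess:
    sess.run(weight_init_op)                    # initialize SPN weights once
    for epoch in range(num_epochs):
        for i in range(0, len(train_x), batch_size):
            x = train_x[i:i + batch_size]
            y = train_y[i:i + batch_size]
            # One gradient step plus per-sample prediction matches
            _, matches = sess.run(
                [gd_update_op, match_op],
                feed_dict={leaf_indicators: x, class_indicators: y})
        print("epoch {}: last-batch accuracy {:.3f}".format(epoch, matches.mean()))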
# Op for initializing all weights
weight_init_op = spn.initialize_weights(root)
# Op for getting the log probability of the root
root_log_prob = root.get_log_value(inference_type=inference_type)

# Helper for constructing EM learning ops
em_learning = spn.HardEMLearning(initial_accum_value=initial_accum_value,
                                 root=root, value_inference_type=inference_type)
# Accumulate counts and update weights
online_em_update_op = em_learning.accumulate_and_update_weights()
# Op for initializing accumulators
init_accumulators = em_learning.reset_accumulators()

# MPE state generator
mpe_state_generator = spn.MPEState()
# Generate MPE state ops for leaf indicators and class indicators
leaf_indicator_mpe, class_indicator_mpe = mpe_state_generator.get_state(
    root, leaf_indicators, class_indicators)

spn.display_tf_graph()

# Set up some convenient iterators
train_iterator = DataIterator([train_x, train_y], batch_size=batch_size)
test_iterator = DataIterator([test_x, test_y], batch_size=batch_size)


def fd(x, y):
    return {leaf_indicators: x, class_indicators: y}
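# Hedged sketch of the online hard-EM loop these ops are meant to drive. The
# DataIterator.iter_epoch usage mirrors the libspn tutorial notebooks but is
# an assumption here, as is `num_epochs`.
with tf.Session() as sess:
    sess.run([weight_init_op, init_accumulators])
    for epoch in range(num_epochs):
        for x, y in train_iterator.iter_epoch("Train"):
            sess.run(online_em_update_op, feed_dict=fd(x, y))
        # Track the training log-likelihood of the root after each epoch
        llh = sess.run(root_log_prob, feed_dict=fd(train_x, train_y))
        print("epoch {}: mean log p(x, y) = {:.4f}".format(epoch, llh.mean()))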
def setup_learning(args, in_var, root):
    no_op = tf.constant(0)
    inference_type = spn.InferenceType.MARGINAL if args.value_inf_type == 'marginal' \
        else spn.InferenceType.MPE
    mpe_state = spn.MPEState(value_inference_type=inference_type, matmul_or_conv=True)
    if args.supervised:
        # Root is provided with labels, p(x,y)
        labels_node = root.generate_latent_indicators(name="LabelIndicators")

        # Marginalized root, i.e. without filling in labels: p(x) = \sum_y p(x,y)
        root_marginalized = spn.Sum(*root.values, name="RootMarginalized",
                                    weights=root.weights)
        # A dummy node to get MPE state
        labels_no_evidence_node = root_marginalized.generate_latent_indicators(
            name="LabelsNoEvidenceIndicators",
            feed=-tf.ones([tf.shape(in_var.feed)[0], 1], dtype=tf.int32))

        # Get prediction from dummy node
        with tf.name_scope("Prediction"):
            logger.info("Setting up MPE state")
            if args.completion_by_marginal and isinstance(in_var, ContinuousLeafBase):
                in_var_mpe = in_var.impute_by_posterior_marginal(
                    labels_no_evidence_node)
                class_mpe, = mpe_state.get_state(root_marginalized,
                                                 labels_no_evidence_node)
            else:
                class_mpe, in_var_mpe = mpe_state.get_state(
                    root_marginalized, labels_no_evidence_node, in_var)
            correct = tf.squeeze(tf.equal(class_mpe, tf.to_int64(labels_node.feed)))
    else:
        with tf.name_scope("Prediction"):
            class_mpe = correct = no_op
            labels_node = root_marginalized = None
            if args.completion_by_marginal and isinstance(in_var, ContinuousLeafBase):
                in_var_mpe = in_var.impute_by_posterior_marginal(root)
            else:
                in_var_mpe, = mpe_state.get_state(root, in_var)

    # Get the log likelihood
    with tf.name_scope("LogLikelihoods"):
        logger.info("Setting up log-likelihood")
        val_gen = spn.LogValue(inference_type=inference_type)
        labels_llh = val_gen.get_value(root)
        no_labels_llh = val_gen.get_value(root_marginalized) if args.supervised \
            else labels_llh

    if args.learning_algo == "em":
        em_learning = spn.HardEMLearning(
            root, value_inference_type=inference_type,
            initial_accum_value=args.initial_accum_value,
            sample_winner=args.sample_path, sample_prob=args.sample_prob,
            use_unweighted=args.use_unweighted)
        accumulate = em_learning.accumulate_updates()
        with tf.control_dependencies([accumulate]):
            update_op = em_learning.update_spn()
        return correct, labels_node, labels_llh, no_labels_llh, update_op, \
            class_mpe, no_op, no_op, in_var_mpe

    logger.info("Setting up GD learning")
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        args.learning_rate, global_step, args.lr_decay_steps,
        args.lr_decay_rate, staircase=True)
    learning_method = spn.LearningMethodType.DISCRIMINATIVE \
        if args.learning_type == 'discriminative' else \
        spn.LearningMethodType.GENERATIVE
    learning = spn.GDLearning(
        root,
        learning_task_type=spn.LearningTaskType.SUPERVISED if args.supervised
        else spn.LearningTaskType.UNSUPERVISED,
        learning_method=learning_method,
        learning_rate=learning_rate,
        marginalizing_root=root_marginalized,
        global_step=global_step)
    optimizer = {
        'adam': tf.train.AdamOptimizer,
        'rmsprop': tf.train.RMSPropOptimizer,
        'amsgrad': AMSGrad,
    }[args.learning_algo]()
    minimize_op, _ = learning.learn(optimizer=optimizer)

    logger.info("Setting up test loss")
    with tf.name_scope("DeterministicLoss"):
        main_loss = learning.loss()
        regularization_loss = learning.regularization_loss()
        loss_per_sample = learning.loss(
            reduce_fn=lambda x: tf.reshape(x, (-1,)))

    return correct, labels_node, main_loss, no_labels_llh, minimize_op, \
        class_mpe, regularization_loss, loss_per_sample, in_var_mpe
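# Hedged caller sketch for setup_learning(). `args`, `in_var` and `root` are
# assumed to be built by the surrounding script (an argparse namespace, the
# leaf node and the SPN root); `batches` is a hypothetical data source, and
# the variable names mirror the GD branch's return order.
correct, labels_node, main_loss, no_labels_llh, minimize_op, class_mpe, \
    reg_loss, loss_per_sample, in_var_mpe = setup_learning(args, in_var, root)

with tf.Session() as sess:
    sess.run([spn.initialize_weights(root), tf.global_variables_initializer()])
    for step, (x, y) in enumerate(batches):
        feed = {in_var: x}
        if args.supervised:
            feed[labels_node] = y          # labels fed only in supervised mode
        _, loss = sess.run([minimize_op, main_loss], feed_dict=feed)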