def test_mpe_state(self):
    """Check MPE state inference on the Poon11 naive mixture model.

    The MPE state of the model's latent indicators is generated twice, once
    with ``log=False`` and once with ``log=True``, and each result is
    compared against ``model.true_mpe_state``.
    """
    # Build the SPN under test
    model = spn.Poon11NaiveMixtureModel()
    model.build()

    # Weight initialisation op and one MPE state generator per value path
    init_op = spn.initialize_weights(model.root)
    linear_gen = spn.MPEState(
        log=False, value_inference_type=spn.InferenceType.MPE)
    log_gen = spn.MPEState(
        log=True, value_inference_type=spn.InferenceType.MPE)
    # A single node is requested, so get_state yields a single state
    (linear_state,) = linear_gen.get_state(
        model.root, model.latent_indicators)
    (log_state,) = log_gen.get_state(
        model.root, model.latent_indicators)

    # Both runs are fed the same input: -1 for each of the two entries
    feed = {model.latent_indicators: [[-1, -1]]}
    with self.test_session() as sess:
        init_op.run()
        linear_out = sess.run(linear_state, feed_dict=feed)
        log_out = sess.run(log_state, feed_dict=feed)

    # For now only the MPE state obtained for the -1 input is compared
    expected = model.true_mpe_state
    np.testing.assert_array_equal(linear_out.ravel(), expected)
    np.testing.assert_array_equal(log_out.ravel(), expected)
def _run_op_test(self, op_fun, input_dist='RAW', node_type=None,
                 inf_type=spn.InferenceType.MARGINAL, log=False, on_gpu=True):
    """Run a single test for a single op.

    Builds the SPN returned by ``op_fun``, trains it for ``self.num_epochs``
    epochs on the training split from ``self._data_set(op_fun)``, optionally
    writes Chrome-trace timelines of one extra epoch, and evaluates the MPE
    state of the latent variable on the test split.

    Args:
        op_fun: Callable that generates the SPN. It is called with
            ``(inputs_pl, self.num_decomps, self.num_subsets,
            self.num_mixtures, self.num_input_mixtures, self.balanced,
            input_dist, node_type, inf_type, log)`` and must return the tuple
            ``(root, latent, learning, additive_smoothing,
            min_additive_smoothing, additive_smoothing_var)``.
        input_dist: Input distribution tag forwarded to ``op_fun`` and used
            in the log line and profile file names.
        node_type: ``spn.DenseSPNGenerator.NodeType`` forwarded to
            ``op_fun``. Anything other than ``SINGLE`` or ``BLOCK``
            (including ``None``) is reported as ``"LAYER"``.
        inf_type: Inference type forwarded to ``op_fun``. Anything other
            than ``spn.InferenceType.MPE`` is reported as ``"MARGINAL"``.
        log: Forwarded to ``op_fun`` and to ``spn.MPEState``.
        on_gpu: If True the graph is built under ``/gpu:0`` and the peak
            memory usage is recorded; otherwise ``/cpu:0`` is used and the
            reported memory usage is ``None``.

    Returns:
        An ``OpTestResult`` holding the op name, device flag, node type
        label, SPN size, TF graph size, memory used, input distribution,
        setup time, weight initialisation time, per-epoch run times and the
        test accuracy.
    """
    # Preparations
    op_name = op_fun.__name__
    device_name = '/gpu:0' if on_gpu else '/cpu:0'

    # Labels are computed once and reused for the log line, the profile file
    # names and the returned result, so they cannot drift apart.
    if node_type == spn.DenseSPNGenerator.NodeType.SINGLE:
        node_type_name = "SINGLE"
    elif node_type == spn.DenseSPNGenerator.NodeType.BLOCK:
        node_type_name = "BLOCK"
    else:
        node_type_name = "LAYER"
    inf_type_name = "MPE" if inf_type == spn.InferenceType.MPE else "MARGINAL"

    # Print
    print2(
        "--> %s: on_gpu=%s, input_dist=%s, inference=%s, node_type=%s, log=%s"
        % (op_name, on_gpu, input_dist, inf_type_name, node_type_name, log),
        self.file)

    train_set, train_labels, test_set, test_labels = self._data_set(op_fun)

    # Create graph
    tf.reset_default_graph()
    with tf.device(device_name):
        # Create input ivs
        inputs_pl = spn.IVs(num_vars=196, num_vals=2)
        # Create dense SPN and generate TF graph for training
        start_time = time.time()
        # Generate SPN
        root, latent, learning, additive_smoothing, min_additive_smoothing, \
            additive_smoothing_var = op_fun(
                inputs_pl, self.num_decomps, self.num_subsets,
                self.num_mixtures, self.num_input_mixtures, self.balanced,
                input_dist, node_type, inf_type, log)
        # Add Learning Ops
        init_weights = spn.initialize_weights(root)
        reset_accumulators = learning.reset_accumulators()
        accumulate_updates = learning.accumulate_updates()
        update_spn = learning.update_spn()

        # Generate Testing Ops; only the latent state is evaluated below
        mpe_state_gen = spn.MPEState(
            log=log, value_inference_type=spn.InferenceType.MPE)
        _, mpe_latent = mpe_state_gen.get_state(root, inputs_pl, latent)
        setup_time = time.time() - start_time

        if on_gpu:
            max_bytes_used_op = tf.contrib.memory_stats.MaxBytesInUse()

    # Get num of SPN ops
    spn_size = root.get_num_nodes()
    # Get num of graph ops
    tf_size = len(tf.get_default_graph().get_operations())

    # Smoothing Decay for Additive Smoothing
    smoothing_decay = 0.2

    # Run op multiple times
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=False,
            log_device_placement=self.log_devs)) as sess:
        # Initialize weights of the SPN
        start_time = time.time()
        init_weights.run()
        weights_init_time = time.time() - start_time

        # Reset accumulators
        sess.run(reset_accumulators)

        run_times = []
        # Create feed dictionary
        feed = {inputs_pl: train_set, latent: train_labels}
        # Run Training
        for epoch in range(self.num_epochs):
            start_time = time.time()
            # Adjust smoothing: exponential decay, floored at the minimum
            ads = max(np.exp(-epoch * smoothing_decay) * additive_smoothing,
                      min_additive_smoothing)
            # NOTE(review): assign() is called per epoch inside the timed
            # region; in TF1 graph mode this presumably adds a new op to the
            # graph on every call -- confirm whether that is intended.
            sess.run(additive_smoothing_var.assign(ads))
            # Run accumulate_updates
            sess.run(accumulate_updates, feed_dict=feed)
            # Update weights
            sess.run(update_spn)
            # Reset accumulators
            sess.run(reset_accumulators)
            run_times.append(time.time() - start_time)

        if on_gpu:
            memory_used = sess.run(max_bytes_used_op)
        else:
            memory_used = None

        if self.profile:
            # Add additional options to trace the session execution
            options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            # (file suffix, op to run, feed dict) for one traced epoch
            profiled_steps = [
                ("acc_updt", accumulate_updates, feed),
                ("spn_updt", update_spn, None),
                ("acc_rst", reset_accumulators, None),
            ]
            # Run a single epoch first, collecting metadata for every step
            collected = []
            for suffix, step_op, step_feed in profiled_steps:
                run_metadata = tf.RunMetadata()
                sess.run(step_op, feed_dict=step_feed, options=options,
                         run_metadata=run_metadata)
                collected.append((suffix, run_metadata))

            # Create the Timeline objects and their Chrome traces
            traces = [
                (suffix,
                 timeline.Timeline(
                     run_metadata.step_stats).generate_chrome_trace_format())
                for suffix, run_metadata in collected
            ]

            if not os.path.exists(self.profiles_dir):
                os.makedirs(self.profiles_dir)

            # e.g. "<op>_GPU_RAW_SINGLE_MPE-LOG"; the SINGLE fragment used to
            # be "_ SINGLE", which put a stray space into the file name.
            file_name = "%s_%s_%s_%s_%s%s" % (
                op_name,
                "GPU" if on_gpu else "CPU",
                input_dist,  # "RAW" or "MIXTURE"
                node_type_name,
                inf_type_name,
                "-LOG" if log else "")

            # Write each trace to its own json file
            for suffix, chrome_trace in traces:
                with open('%s/timeline_%s_%s.json'
                          % (self.profiles_dir, file_name, suffix), 'w') as f:
                    f.write(chrome_trace)

        # Run Testing: the latent variable is fed -1 for every test sample
        mpe_latent_val = sess.run(mpe_latent, feed_dict={
            inputs_pl: test_set,
            latent: np.ones((test_set.shape[0], 1)) * -1
        })
        result = (mpe_latent_val == test_labels)
        test_accuracy = np.sum(result) / test_labels.size

    # Return stats
    return OpTestResult(
        op_name, on_gpu, node_type_name, spn_size, tf_size, memory_used,
        input_dist, setup_time, weights_init_time, run_times, test_accuracy)