def session_run_job():
    """Run `y = a * b + c` once under gated-gRPC DebugIdentity watches.

    Closure over the enclosing test: reads `self._debugger_url` and
    `variable_size`, and appends the fetched value of `y` to
    `session_run_results`.
    """
    with tf.Session() as sess:
        var_a = tf.Variable([10.0] * variable_size, name='a')
        var_b = tf.Variable([20.0] * variable_size, name='b')
        var_c = tf.Variable([30.0] * variable_size, name='c')
        # x = a * b, then y = x + c.
        y_tensor = tf.add(tf.multiply(var_a, var_b, name="x"), var_c, name="y")
        sess.run(tf.global_variables_initializer())

        watched_options = tf.RunOptions()
        tf_debug.watch_graph(
            watched_options,
            sess.graph,
            debug_ops="DebugIdentity(gated_grpc=True)",
            debug_urls=self._debugger_url,
        )
        y_value = sess.run(y_tensor, options=watched_options)
        session_run_results.append(y_value)
def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
    """Build z = (a + b) + (c + d) and RunOptions that watch `sess.graph`.

    Args:
      sess: Session whose graph receives the debug watches.
      gated_grpc: If True, the DebugIdentity op is created with
        `gated_grpc=True`.

    Returns:
      A `(z, run_options)` tuple. `run_options` requests partition graphs
      and streams debug data to `self.debug_server_url`.
    """
    var_a = tf.Variable([1.0], name='a')
    var_b = tf.Variable([2.0], name='b')
    var_c = tf.Variable([3.0], name='c')
    var_d = tf.Variable([4.0], name='d')
    left_sum = tf.add(var_a, var_b, name='x')
    right_sum = tf.add(var_c, var_d, name='y')
    z = tf.add(left_sum, right_sum, name='z')

    options = tf.RunOptions(output_partition_graphs=True)
    if gated_grpc:
        op_spec = 'DebugIdentity(gated_grpc=True)'
    else:
        op_spec = 'DebugIdentity'
    tf_debug.watch_graph(
        options, sess.graph, debug_ops=op_spec,
        debug_urls=self.debug_server_url)
    return z, options
def session_run_job():
    """Increment int32 variables `a` and `b` by 1 in one grouped run.

    Closure over the enclosing test: reads `self._debugger_url` and appends
    the result of running the `inc_ab` group op to `session_run_results`.
    """
    with tf.Session() as sess:
        counter_a = tf.Variable(10, dtype=tf.int32, name='a')
        counter_b = tf.Variable(20, dtype=tf.int32, name='b')
        one = tf.constant(1, dtype=tf.int32, name='d')
        increments = [
            tf.assign_add(counter_a, one, name='inc_a'),
            tf.assign_add(counter_b, one, name='inc_b'),
        ]
        inc_both = tf.group(increments, name="inc_ab")
        sess.run(tf.global_variables_initializer())

        watched_options = tf.RunOptions()
        tf_debug.watch_graph(
            watched_options,
            sess.graph,
            debug_ops="DebugIdentity(gated_grpc=True)",
            debug_urls=self._debugger_url,
        )
        outcome = sess.run(inc_both, options=watched_options)
        session_run_results.append(outcome)
def session_run_job():
    """Run `inc_a` (a += b) `steps` times under gated-gRPC debug watches.

    Closure over the enclosing test: reads `self._debugger_url` and `steps`,
    and appends each run's result to `session_run_results`.
    """
    with tf.Session() as sess:
        accumulator = tf.Variable(10, dtype=tf.int32, name='a')
        increment_by = tf.Variable(1, dtype=tf.int32, name='b')
        increment = tf.assign_add(accumulator, increment_by, name='inc_a')
        sess.run(tf.global_variables_initializer())

        watched_options = tf.RunOptions()
        tf_debug.watch_graph(
            watched_options,
            sess.graph,
            debug_ops="DebugIdentity(gated_grpc=True)",
            debug_urls=self._debugger_url,
        )
        for _step in range(steps):
            value = sess.run(increment, options=watched_options)
            session_run_results.append(value)
def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self):
    """Three int32 increments of `x` are each recorded as health pills."""
    with tf.Session() as sess:
        start_val = np.array([10], dtype=np.int32)
        x_var = tf.Variable(
            tf.constant(start_val, shape=[1], name="x_init"), name="x")
        step_val = np.array([2], dtype=np.int32)
        step_const = tf.constant(step_val, name="x_inc")
        increment_op = tf.assign_add(x_var, step_const, name="inc_x")
        sess.run(x_var.initializer)

        options = tf.RunOptions(output_partition_graphs=True)
        tf_debug.watch_graph(
            options,
            sess.graph,
            debug_ops=["DebugNumericSummary"],
            debug_urls=[self._debug_url],
        )
        # Increase three times.
        for _ in range(3):
            sess.run(increment_op, options=options)

    # Debugger data is stored within a special directory within logdir.
    pattern = os.path.join(
        self._logdir,
        constants.DEBUGGER_DATA_DIRECTORY_NAME,
        "events.debugger*",
    )
    event_files = glob.glob(pattern)
    self.assertEqual(1, len(event_files))

    expected_pills = {
        "x_inc:0:DebugNumericSummary": [step_val] * 3,
        "x:0:DebugNumericSummary": [
            start_val,
            start_val + step_val,
            start_val + 2 * step_val,
        ],
    }
    self._check_health_pills_in_events_file(event_files[0], expected_pills)
def _poll_server_till_success(self, max_tries, poll_interval_seconds):
    """Retry a small watched Session.run until the debug server accepts it.

    Each attempt builds a constant-initialized variable `a` in the default
    graph, watches the graph with DebugNumericSummary streaming to
    `self._debug_url`, and runs `a.initializer`.

    Args:
      max_tries: Maximum number of attempts.
      poll_interval_seconds: Seconds to wait between failed attempts.

    Returns:
      True as soon as an attempt succeeds. False if every attempt raised
      `tf.errors.FailedPreconditionError` (presumably because the server is
      not reachable yet; confirm).
    """
    for attempt in range(max_tries):
        try:
            with tf.Session() as sess:
                a_init_val = np.array([42.0])
                a_init = tf.constant(a_init_val, shape=[1], name="a_init")
                a = tf.Variable(a_init, name="a")
                run_options = tf.RunOptions(output_partition_graphs=True)
                tf_debug.watch_graph(
                    run_options,
                    sess.graph,
                    debug_ops=["DebugNumericSummary"],
                    debug_urls=[self._debug_url],
                )
                sess.run(a.initializer, options=run_options)
                return True
        except tf.errors.FailedPreconditionError:
            # Only wait when another attempt follows; sleeping after the
            # last failure would just delay returning False.
            if attempt < max_tries - 1:
                time.sleep(poll_interval_seconds)
    return False
def _createTestGraphAndRunOptions(self, sess, gated_grpc=True):
    """Build z = (a + b) + (c + d) and v1 RunOptions that watch `sess.graph`.

    Args:
      sess: Session whose graph receives the debug watches.
      gated_grpc: If True, the DebugIdentity op is created with
        `gated_grpc=True`.

    Returns:
      A `(z, run_options)` tuple. `run_options` requests partition graphs
      and streams debug data to `self.debug_server_url`.
    """
    var_a = tf.Variable([1.0], name="a")
    var_b = tf.Variable([2.0], name="b")
    var_c = tf.Variable([3.0], name="c")
    var_d = tf.Variable([4.0], name="d")
    left_sum = tf.add(var_a, var_b, name="x")
    right_sum = tf.add(var_c, var_d, name="y")
    z = tf.add(left_sum, right_sum, name="z")

    options = tf.compat.v1.RunOptions(output_partition_graphs=True)
    op_spec = (
        "DebugIdentity(gated_grpc=True)" if gated_grpc else "DebugIdentity"
    )
    tf_debug.watch_graph(
        options,
        sess.graph,
        debug_ops=op_spec,
        debug_urls=self.debug_server_url,
    )
    return z, options
def testMultipleInt32ValuesOverMultipleRunsAreRecorded(self):
    """Each of three int32 increments of `x` produces a health pill."""
    with tf.Session() as sess:
        initial = np.array([10], dtype=np.int32)
        x_var = tf.Variable(
            tf.constant(initial, shape=[1], name="x_init"), name="x")
        delta = np.array([2], dtype=np.int32)
        delta_const = tf.constant(delta, name="x_inc")
        bump = tf.assign_add(x_var, delta_const, name="inc_x")
        sess.run(x_var.initializer)

        watched = tf.RunOptions(output_partition_graphs=True)
        tf_debug.watch_graph(
            watched,
            sess.graph,
            debug_ops=["DebugNumericSummary"],
            debug_urls=[self._debug_url],
        )
        # Increase three times.
        for _ in range(3):
            sess.run(bump, options=watched)

    # Debugger data is stored within a special directory within logdir.
    event_files = glob.glob(os.path.join(
        self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME,
        "events.debugger*"))
    self.assertEqual(1, len(event_files))

    x_history = [initial, initial + delta, initial + 2 * delta]
    self._check_health_pills_in_events_file(
        event_files[0],
        {
            "x_inc:0:DebugNumericSummary": [delta] * 3,
            "x:0:DebugNumericSummary": x_history,
        },
    )
def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
    """Numerics alerts from concurrent Session.run threads are aggregated.

    u = x / y divides 2.0 by 0.0, so u:0 contains +Inf. v = z @ u then
    contains NaN (0 * Inf) and -Inf. Every run in every thread must be
    counted in the server's numerics alert report.
    """
    num_threads = 3
    num_runs_per_thread = 2
    total_num_runs = num_threads * num_runs_per_thread

    # Before any Session runs, the report ought to be empty.
    self.assertEqual([], self._debug_data_server.numerics_alert_report())

    with tf.compat.v1.Session() as sess:
        x_init_val = np.array([[2.0], [-1.0]])
        y_init_val = np.array([[0.0], [-0.25]])
        z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

        x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
        x = tf.Variable(x_init, name="x")
        y_init = tf.constant(y_init_val, shape=[2, 1])
        y = tf.Variable(y_init, name="y")
        z_init = tf.constant(z_init_val, shape=[2, 2])
        z = tf.Variable(z_init, name="z")

        u = tf.compat.v1.div(x, y, name="u")  # Produces an Inf.
        v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

        sess.run(x.initializer)
        sess.run(y.initializer)
        sess.run(z.initializer)

        run_options_list = []
        for i in range(num_threads):
            run_options = tf.compat.v1.RunOptions(
                output_partition_graphs=True)
            # Use different grpc:// URL paths so that each thread opens a
            # separate gRPC stream to the debug data server, simulating a
            # multi-worker setting.
            tf_debug.watch_graph(
                run_options,
                sess.graph,
                debug_ops=["DebugNumericSummary"],
                debug_urls=[self._debug_url + "/thread%d" % i])
            run_options_list.append(run_options)

        thread_errors = []

        def run_v(thread_id):
            try:
                for _ in range(num_runs_per_thread):
                    sess.run(v, options=run_options_list[thread_id])
            except Exception as e:  # pylint: disable=broad-except
                # Exceptions do not propagate out of a Thread; record them so
                # the main thread fails with the real cause instead of a
                # confusing count mismatch below.
                thread_errors.append(e)

        run_threads = []
        for thread_id in range(num_threads):
            thread = threading.Thread(
                target=functools.partial(run_v, thread_id))
            thread.start()
            run_threads.append(thread)

        for thread in run_threads:
            thread.join()

        self.assertEqual([], thread_errors)

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # u:0 only ever holds +Inf.
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(total_num_runs, report[0].pos_inf_event_count)

    # v:0 holds NaN and -Inf.
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(total_num_runs, report[1].nan_event_count)
    self.assertEqual(total_num_runs, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
def testRunSimpleNetworkoWithInfAndNaNWorks(self):
    """A run producing Inf/NaN is dumped and raises numerics alerts.

    u = x / y divides 2.0 by 0.0, so u:0 contains +Inf. v = z @ u then
    contains NaN (0 * Inf) and -Inf.
    """
    with tf.compat.v1.Session() as sess:
        x_init_val = np.array([[2.0], [-1.0]])
        y_init_val = np.array([[0.0], [-0.25]])
        z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

        x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
        x = tf.Variable(x_init, name="x")
        y_init = tf.constant(y_init_val, shape=[2, 1])
        y = tf.Variable(y_init, name="y")
        z_init = tf.constant(z_init_val, shape=[2, 2])
        z = tf.Variable(z_init, name="z")

        u = tf.compat.v1.div(x, y, name="u")  # Produces an Inf.
        v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

        sess.run(x.initializer)
        sess.run(y.initializer)
        sess.run(z.initializer)

        run_options = tf.compat.v1.RunOptions(output_partition_graphs=True)
        tf_debug.watch_graph(
            run_options,
            sess.graph,
            debug_ops=["DebugNumericSummary"],
            debug_urls=[self._debug_url])

        result = sess.run(v, options=run_options)
        self.assertTrue(np.isnan(result[0, 0]))
        self.assertEqual(-np.inf, result[1, 0])

    # Debugger data is stored within a special directory within logdir.
    event_files = glob.glob(
        os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME,
                     "events.debugger*"))
    self.assertEqual(1, len(event_files))
    self._check_health_pills_in_events_file(
        event_files[0],
        {
            "x:0:DebugNumericSummary": [x_init_val],
            "y:0:DebugNumericSummary": [y_init_val],
            "z:0:DebugNumericSummary": [z_init_val],
            "u:0:DebugNumericSummary": [x_init_val / y_init_val],
            "v:0:DebugNumericSummary":
                [np.matmul(z_init_val, x_init_val / y_init_val)],
        })

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # u:0 only holds +Inf.
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(1, report[0].pos_inf_event_count)

    # v:0 holds NaN and -Inf.
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(1, report[1].nan_event_count)
    self.assertEqual(1, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
    """Numerics alerts from concurrent Session.run threads are aggregated.

    u = x / y divides 2.0 by 0.0, so u:0 contains +Inf. v = z @ u then
    contains NaN (0 * Inf) and -Inf. Every run in every thread must be
    counted in the server's numerics alert report.
    """
    num_threads = 3
    num_runs_per_thread = 2
    total_num_runs = num_threads * num_runs_per_thread

    # Before any Session runs, the report ought to be empty.
    self.assertEqual([], self._debug_data_server.numerics_alert_report())

    with tf.Session() as sess:
        x_init_val = np.array([[2.0], [-1.0]])
        y_init_val = np.array([[0.0], [-0.25]])
        z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

        x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
        x = tf.Variable(x_init, name="x")
        y_init = tf.constant(y_init_val, shape=[2, 1])
        y = tf.Variable(y_init, name="y")
        z_init = tf.constant(z_init_val, shape=[2, 2])
        z = tf.Variable(z_init, name="z")

        u = tf.div(x, y, name="u")  # Produces an Inf.
        v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

        sess.run(x.initializer)
        sess.run(y.initializer)
        sess.run(z.initializer)

        run_options_list = []
        for i in range(num_threads):
            run_options = tf.RunOptions(output_partition_graphs=True)
            # Use different grpc:// URL paths so that each thread opens a
            # separate gRPC stream to the debug data server, simulating a
            # multi-worker setting.
            tf_debug.watch_graph(
                run_options,
                sess.graph,
                debug_ops=["DebugNumericSummary"],
                debug_urls=[self._debug_url + "/thread%d" % i])
            run_options_list.append(run_options)

        thread_errors = []

        def run_v(thread_id):
            try:
                for _ in range(num_runs_per_thread):
                    sess.run(v, options=run_options_list[thread_id])
            except Exception as e:  # pylint: disable=broad-except
                # Exceptions do not propagate out of a Thread; record them so
                # the main thread fails with the real cause.
                thread_errors.append(e)

        run_threads = []
        for thread_id in range(num_threads):
            thread = threading.Thread(
                target=functools.partial(run_v, thread_id))
            thread.start()
            run_threads.append(thread)

        for thread in run_threads:
            thread.join()

        self.assertEqual([], thread_errors)

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # Match only the device suffix, case-insensitively. The full device
    # string format ("cpu:0" vs "device:CPU:0") varies across TF versions.
    # u:0 only ever holds +Inf.
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(total_num_runs, report[0].pos_inf_event_count)

    # v:0 holds NaN and -Inf.
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(total_num_runs, report[1].nan_event_count)
    self.assertEqual(total_num_runs, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
def testRunSimpleNetworkoWithInfAndNaNWorks(self):
    """A run producing Inf/NaN is dumped and raises numerics alerts.

    u = x / y divides 2.0 by 0.0, so u:0 contains +Inf. v = z @ u then
    contains NaN (0 * Inf) and -Inf.
    """
    with tf.Session() as sess:
        x_init_val = np.array([[2.0], [-1.0]])
        y_init_val = np.array([[0.0], [-0.25]])
        z_init_val = np.array([[0.0, 3.0], [-1.0, 0.0]])

        x_init = tf.constant(x_init_val, shape=[2, 1], name="x_init")
        x = tf.Variable(x_init, name="x")
        y_init = tf.constant(y_init_val, shape=[2, 1])
        y = tf.Variable(y_init, name="y")
        z_init = tf.constant(z_init_val, shape=[2, 2])
        z = tf.Variable(z_init, name="z")

        u = tf.div(x, y, name="u")  # Produces an Inf.
        v = tf.matmul(z, u, name="v")  # Produces NaN and Inf.

        sess.run(x.initializer)
        sess.run(y.initializer)
        sess.run(z.initializer)

        run_options = tf.RunOptions(output_partition_graphs=True)
        tf_debug.watch_graph(
            run_options,
            sess.graph,
            debug_ops=["DebugNumericSummary"],
            debug_urls=[self._debug_url])

        result = sess.run(v, options=run_options)
        self.assertTrue(np.isnan(result[0, 0]))
        self.assertEqual(-np.inf, result[1, 0])

    # Debugger data is stored within a special directory within logdir.
    event_files = glob.glob(
        os.path.join(self._logdir, constants.DEBUGGER_DATA_DIRECTORY_NAME,
                     "events.debugger*"))
    self.assertEqual(1, len(event_files))
    self._check_health_pills_in_events_file(event_files[0], {
        "x:0:DebugNumericSummary": [x_init_val],
        "y:0:DebugNumericSummary": [y_init_val],
        "z:0:DebugNumericSummary": [z_init_val],
        "u:0:DebugNumericSummary": [x_init_val / y_init_val],
        "v:0:DebugNumericSummary": [
            np.matmul(z_init_val, x_init_val / y_init_val)
        ],
    })

    report = self._debug_data_server.numerics_alert_report()
    self.assertEqual(2, len(report))

    # Match only the device suffix, case-insensitively. The full device
    # string format varies across TF versions.
    # u:0 only holds +Inf.
    self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
    self.assertEqual("u:0", report[0].tensor_name)
    self.assertGreater(report[0].first_timestamp, 0)
    self.assertEqual(0, report[0].nan_event_count)
    self.assertEqual(0, report[0].neg_inf_event_count)
    self.assertEqual(1, report[0].pos_inf_event_count)

    # v:0 holds NaN and -Inf.
    self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
    self.assertEqual("v:0", report[1].tensor_name)
    self.assertGreaterEqual(report[1].first_timestamp,
                            report[0].first_timestamp)
    self.assertEqual(1, report[1].nan_event_count)
    self.assertEqual(1, report[1].neg_inf_event_count)
    self.assertEqual(0, report[1].pos_inf_event_count)
def train_original(model, config, session=None):
    """Train a WGAN with RMSProp and critic weight clipping.

    Each generator update follows `config.critic_update_ratio` critic updates.
    Every critic update consumes one batch from the dataset iterator, and an
    epoch ends when that iterator is exhausted (StopIteration). A checkpoint
    is saved after every epoch.

    Args:
        model: Object exposing `G`, `c_loss`, `g_loss`, `c_vars`, `g_vars`,
            `z_in`, `image_in` and `z_size`.
        config: Options read here: `log_dir`, `learning_rate`, `clip_size`,
            `execution_graph_dump_to`, `resume`, `epochs`, `dataset`,
            `batch_size`, `critic_update_ratio`, `critic_dump_to`,
            `generator_dump_to`, `sample_size`, `image_log_interval` and
            `loss_log_interval`.
        session: Optional `tf.Session`; a new one is created when omitted.
            It is used as a context manager, so it is closed on return.

    Setting any `*_dump_to` option writes a single debug/graph dump and then
    terminates the process with exit status 0.
    """
    import os
    import sys

    # define a session if needed.
    session = session or tf.Session()

    # define summaries.
    summary_writer = tf.summary.FileWriter(config.log_dir, session.graph)
    image_summary = tf.summary.image(
        'generated images', model.G, max_outputs=8
    )
    loss_summaries = tf.summary.merge([
        tf.summary.scalar('wasserstein distance', -model.c_loss),
        tf.summary.scalar('generator loss', model.g_loss),
    ])

    # define optimizers.
    C_trainer = tf.train.RMSPropOptimizer(
        learning_rate=config.learning_rate
    )
    G_trainer = tf.train.RMSPropOptimizer(
        learning_rate=config.learning_rate
    )

    # define parameter update tasks
    c_grads = C_trainer.compute_gradients(model.c_loss, var_list=model.c_vars)
    g_grads = G_trainer.compute_gradients(model.g_loss, var_list=model.g_vars)
    update_C = C_trainer.apply_gradients(c_grads)
    update_G = G_trainer.apply_gradients(g_grads)
    clip_C = [
        v.assign(tf.clip_by_value(v, -config.clip_size, config.clip_size))
        for v in model.c_vars
    ]

    def dump_run_and_exit(dump_dir, fetches, feed_dict):
        # Run one step with tfdbg dumping watched tensors to ./<dump_dir>,
        # then terminate the process.
        from tensorflow.python import debug as tf_debug
        run_options = tf.RunOptions()
        tf_debug.watch_graph(
            run_options, session.graph,
            debug_urls=['file://' + os.path.join(os.getcwd(), dump_dir)]
        )
        session.run(fetches, feed_dict=feed_dict, options=run_options)
        sys.exit(0)

    if config.execution_graph_dump_to:
        graph_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), config.execution_graph_dump_to),
            tf.get_default_graph(),
        )
        # FileWriter writes asynchronously; close() flushes the graph event
        # to disk before the process exits.
        graph_writer.close()
        sys.exit(0)

    # main training session context
    with session:
        if config.resume:
            epoch_start = utils.load_checkpoint(session, model, config) + 1
        else:
            epoch_start = 1
            session.run(tf.global_variables_initializer())

        step_counter = 0
        for epoch in range(epoch_start, config.epochs + 1):
            dataset = DATASETS[config.dataset](config.batch_size)
            dataset_length = DATASET_LENGTH_GETTERS[config.dataset]()
            # tqdm's first positional argument is an iterable; the expected
            # batch count has to be given as `total`.
            dataset_stream = tqdm(total=dataset_length // config.batch_size)
            batches_seen = 0
            try:
                # while theta has not converged do
                while True:
                    # for t=0,...,n_critic do
                    for _ in range(config.critic_update_ratio):
                        # Sample {x^(i)}[i=1,m]~Pr a batch from the real data.
                        # StopIteration here ends the epoch (handled below).
                        xs = next(dataset)
                        batches_seen += 1
                        dataset_stream.update()
                        # Sample {z^(i)}[i=1,m]~p(z) a batch of prior samples.
                        zs = _sample_z(config.batch_size, model.z_size)
                        critic_feed = {model.z_in: zs, model.image_in: xs}
                        if config.critic_dump_to:
                            dump_run_and_exit(
                                config.critic_dump_to,
                                [update_C, model.c_loss],
                                critic_feed,
                            )
                        # g_w <- grad_w[mean(f_w(x^(i))) - mean(f_w(g_theta(z^(i))))]
                        # w <- w + alpha * RMSProp(w, g_w)
                        _, c_loss = session.run(
                            [update_C, model.c_loss], feed_dict=critic_feed
                        )
                        # w <- clip(w, -c, c)
                        session.run(clip_C)
                    # end for

                    # Sample {z^(i)}[i=1,m]~p(z) a batch of prior samples.
                    zs = _sample_z(config.batch_size, model.z_size)
                    generator_feed = {model.z_in: zs}
                    if config.generator_dump_to:
                        dump_run_and_exit(
                            config.generator_dump_to,
                            [update_G, model.g_loss],
                            generator_feed,
                        )
                    # g_theta <- -grad_theta[mean(f_w(g_theta(z^(i))))]
                    # theta <- theta - alpha * RMSProp(theta, g_theta)
                    _, g_loss = session.run(
                        [update_G, model.g_loss], feed_dict=generator_feed
                    )

                    # display current training process status
                    step_counter += 1
                    # Samples consumed this epoch. Each generator step uses
                    # critic_update_ratio batches, so this is tracked per
                    # batch rather than derived from step_counter.
                    trained = batches_seen * config.batch_size
                    dataset_stream.set_description((
                        'epoch: {epoch}/{epochs} | '
                        'progress: [{trained}/{total}] ({progress:.0f}%) | '
                        'g loss: {g_loss:.3f} | '
                        'w distance: {w_dist:.3f}'
                    ).format(
                        epoch=epoch,
                        epochs=config.epochs,
                        trained=trained,
                        total=dataset_length,
                        progress=100. * trained / dataset_length,
                        g_loss=g_loss,
                        w_dist=-c_loss,
                    ))

                    # log the generated samples
                    if step_counter % config.image_log_interval == 0:
                        zs = _sample_z(config.sample_size, model.z_size)
                        summary_writer.add_summary(session.run(
                            image_summary, feed_dict={
                                model.z_in: zs
                            }
                        ), step_counter)

                    # log the losses
                    if step_counter % config.loss_log_interval == 0:
                        zs = _sample_z(config.batch_size, model.z_size)
                        summary_writer.add_summary(session.run(
                            loss_summaries, feed_dict={
                                model.z_in: zs,
                                model.image_in: xs
                            }
                        ), step_counter)
                # end while
            except StopIteration:
                # The dataset iterator is exhausted: the epoch is over. Only
                # this is caught; a bare except would also swallow the
                # SystemExit of the dump paths and KeyboardInterrupt.
                pass
            finally:
                dataset_stream.close()

            # save the model at the every end of the epochs.
            utils.save_checkpoint(session, model, epoch, config)

    # Flush any pending summaries to disk.
    summary_writer.close()
def train(model, config, session=None):
    """Train a WGAN with Adam and critic weight clipping.

    Each dataset batch drives `critic_update_ratio` critic updates (30 for the
    first 24 batches of an epoch and for every 500th batch, otherwise
    `config.critic_update_ratio`), followed by one generator update. A
    checkpoint is saved after every epoch.

    Args:
        model: Object exposing `G`, `c_loss`, `g_loss`, `c_vars`, `g_vars`,
            `z_in`, `image_in` and `z_size`.
        config: Options read here: `log_dir`, `learning_rate`, `beta1`,
            `clip_size`, `execution_graph_dump_to`, `resume`, `epochs`,
            `dataset`, `batch_size`, `critic_update_ratio`, `critic_dump_to`,
            `generator_dump_to`, `sample_size`, `image_log_interval` and
            `loss_log_interval`.
        session: Optional `tf.Session`; a new one is created when omitted.
            It is used as a context manager, so it is closed on return.

    Setting any `*_dump_to` option writes a single debug/graph dump and then
    terminates the process with exit status 0.
    """
    import os
    import sys

    # define a session if needed.
    session = session or tf.Session()

    # define summaries.
    summary_writer = tf.summary.FileWriter(config.log_dir, session.graph)
    image_summary = tf.summary.image(
        'generated images', model.G, max_outputs=8
    )
    loss_summaries = tf.summary.merge([
        tf.summary.scalar('wasserstein distance', -model.c_loss),
        tf.summary.scalar('generator loss', model.g_loss),
    ])

    # define optimizers.
    C_trainer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
    )
    G_trainer = tf.train.AdamOptimizer(
        learning_rate=config.learning_rate,
        beta1=config.beta1,
    )

    # define parameter update tasks
    c_grads = C_trainer.compute_gradients(model.c_loss, var_list=model.c_vars)
    g_grads = G_trainer.compute_gradients(model.g_loss, var_list=model.g_vars)
    update_C = C_trainer.apply_gradients(c_grads)
    update_G = G_trainer.apply_gradients(g_grads)
    clip_C = [
        v.assign(tf.clip_by_value(v, -config.clip_size, config.clip_size))
        for v in model.c_vars
    ]

    def dump_run_and_exit(dump_dir, fetches, feed_dict):
        # Run one step with tfdbg dumping watched tensors to ./<dump_dir>,
        # then terminate the process.
        from tensorflow.python import debug as tf_debug
        run_options = tf.RunOptions()
        tf_debug.watch_graph(
            run_options, session.graph,
            debug_urls=['file://' + os.path.join(os.getcwd(), dump_dir)]
        )
        session.run(fetches, feed_dict=feed_dict, options=run_options)
        sys.exit(0)

    if config.execution_graph_dump_to:
        graph_writer = tf.summary.FileWriter(
            os.path.join(os.getcwd(), config.execution_graph_dump_to),
            tf.get_default_graph(),
        )
        # FileWriter writes asynchronously; close() flushes the graph event
        # to disk before the process exits.
        graph_writer.close()
        sys.exit(0)

    # main training session context
    with session:
        if config.resume:
            epoch_start = utils.load_checkpoint(session, model, config) + 1
        else:
            epoch_start = 1
            session.run(tf.global_variables_initializer())

        for epoch in range(epoch_start, config.epochs + 1):
            dataset = DATASETS[config.dataset](config.batch_size)
            dataset_length = DATASET_LENGTH_GETTERS[config.dataset]()
            batches_per_epoch = dataset_length // config.batch_size
            # An enumerate object has no len(), so give tqdm the total.
            dataset_stream = tqdm(
                enumerate(dataset, 1), total=batches_per_epoch
            )

            for batch_index, xs in dataset_stream:
                # where are we?
                iteration = (epoch - 1) * batches_per_epoch + batch_index

                # place more weight on the critic at the beginning of training.
                critic_update_ratio = (
                    30 if (batch_index < 25 or batch_index % 500 == 0) else
                    config.critic_update_ratio
                )

                # train the critic against the current generator and the data.
                for _ in range(critic_update_ratio):
                    zs = _sample_z(config.batch_size, model.z_size)
                    critic_feed = {model.z_in: zs, model.image_in: xs}
                    if config.critic_dump_to:
                        dump_run_and_exit(
                            config.critic_dump_to,
                            [update_C, model.c_loss],
                            critic_feed,
                        )
                    _, c_loss = session.run(
                        [update_C, model.c_loss], feed_dict=critic_feed
                    )
                    session.run(clip_C)

                # train the generator against the current critic.
                zs = _sample_z(config.batch_size, model.z_size)
                generator_feed = {model.z_in: zs}
                if config.generator_dump_to:
                    dump_run_and_exit(
                        config.generator_dump_to,
                        [update_G, model.g_loss],
                        generator_feed,
                    )
                _, g_loss = session.run(
                    [update_G, model.g_loss], feed_dict=generator_feed
                )

                # display current training process status
                dataset_stream.set_description((
                    'epoch: {epoch}/{epochs} | '
                    'progress: [{trained}/{total}] ({progress:.0f}%) | '
                    'g loss: {g_loss:.3f} | '
                    'w distance: {w_dist:.3f}'
                ).format(
                    epoch=epoch,
                    epochs=config.epochs,
                    trained=batch_index * config.batch_size,
                    total=dataset_length,
                    progress=(
                        100. * batch_index * config.batch_size /
                        dataset_length
                    ),
                    g_loss=g_loss,
                    w_dist=-c_loss,
                ))

                # log the generated samples
                if iteration % config.image_log_interval == 0:
                    zs = _sample_z(config.sample_size, model.z_size)
                    summary_writer.add_summary(session.run(
                        image_summary, feed_dict={
                            model.z_in: zs
                        }
                    ), iteration)

                # log the losses
                if iteration % config.loss_log_interval == 0:
                    zs = _sample_z(config.batch_size, model.z_size)
                    summary_writer.add_summary(session.run(
                        loss_summaries, feed_dict={
                            model.z_in: zs,
                            model.image_in: xs
                        }
                    ), iteration)

            # save the model at the every end of the epochs.
            utils.save_checkpoint(session, model, epoch, config)

    # Flush any pending summaries to disk.
    summary_writer.close()