def train_step_mem(sess, train_op, global_step, train_step_kwargs):
    """Run one training step, optionally logging loss, timing and GPU memory.

    NOTE(review): a later ``def train_step_mem`` in this file rebinds this
    name, so this definition is shadowed once the module finishes loading.
    Consider deleting one of the two.

    Relies on the module-level names ``log_memory``, ``mem_util``, ``tf`` and
    ``time``.

    Args:
        sess: Session used to run the step.
        train_op: Fetch whose evaluated value is returned (and logged) as the
            loss.
        global_step: Fetch whose evaluated value is used only in the log line.
        train_step_kwargs: Dict that may contain ``'should_log'`` and
            ``'should_stop'`` fetches; each is evaluated with ``sess.run``
            only when present.

    Returns:
        A ``(total_loss, should_stop)`` tuple. ``should_stop`` is ``False``
        when ``train_step_kwargs`` has no ``'should_stop'`` entry.
    """
    start_time = time.time()
    # FULL_TRACE is costly; only collect run metadata when the memory report
    # is actually enabled.
    if log_memory:
        run_metadata = tf.RunMetadata()
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
        run_metadata = None
        options = None
    total_loss, np_global_step = sess.run(
        [train_op, global_step], options=options, run_metadata=run_metadata)
    time_elapsed = time.time() - start_time

    if 'should_log' in train_step_kwargs:
        if sess.run(train_step_kwargs['should_log']):
            tf.logging.info(
                'global step %d: loss = %.4f (%.3f sec/step)',
                np_global_step, total_loss, time_elapsed)
            if log_memory:
                # Report every GPU device present rather than indexing
                # '/gpu:0' directly, which raised KeyError when that exact
                # key was missing from the peak-memory map.
                peaks = mem_util.peak_memory(run_metadata)
                for device, peak_bytes in peaks.items():
                    if '/gpu' in device:
                        tf.logging.info('Memory used (%s): %.2f MB',
                                        device, peak_bytes / 1e6)

    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    else:
        should_stop = False
    return total_loss, should_stop
def train_step_mem(sess, train_op, global_step, train_step_kwargs):
    """Run one training step, optionally logging loss, timing and GPU memory.

    Relies on the module-level names ``log_memory``, ``mem_util``,
    ``dist_builder``, ``tf`` and ``time``.

    Args:
        sess: Session used to run the step.
        train_op: Fetch whose evaluated value is returned (and logged) as the
            loss.
        global_step: Fetch whose evaluated value is used only in the log line.
        train_step_kwargs: Dict that may contain ``'should_log'`` and
            ``'should_stop'`` fetches; each is evaluated with ``sess.run``
            only when present.

    Returns:
        A ``(total_loss, should_stop)`` tuple. ``should_stop`` is ``False``
        when ``train_step_kwargs`` has no ``'should_stop'`` entry.
    """
    start_time = time.time()
    # FULL_TRACE is costly; only collect run metadata when the memory report
    # is actually enabled.
    if log_memory:
        run_metadata = tf.RunMetadata()
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    else:
        run_metadata = None
        options = None
    # dist_builder.DEBUG is still evaluated so that whatever debug fetches it
    # presumably holds keep running each step; its value is discarded.
    # grads_and_vars is no longer fetched: its value was never used, fetching
    # it copied every gradient/variable to the host each step, and sess.run
    # rejects None entries (variables without a gradient).
    total_loss, np_global_step, _ = sess.run(
        [train_op, global_step, dist_builder.DEBUG],
        options=options, run_metadata=run_metadata)
    time_elapsed = time.time() - start_time

    if 'should_log' in train_step_kwargs:
        if sess.run(train_step_kwargs['should_log']):
            tf.logging.info(
                'global step %d: loss = %.4f (%.3f sec/step)',
                np_global_step, total_loss, time_elapsed)
            if log_memory:
                # Log the peak usage of every GPU device in the trace.
                peaks = mem_util.peak_memory(run_metadata)
                for device, peak_bytes in peaks.items():
                    if "/gpu" in device:
                        tf.logging.info('Memory used (%s): %.2f MB',
                                        device, peak_bytes / 1e6)

    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    else:
        should_stop = False
    return total_loss, should_stop