def tower_loss(scope, images, labels):
  """Calculate the total loss on a single tower running the CIFAR model.

  Args:
    scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'.
    images: Images. 4D tensor of shape [batch_size, height, width, 3].
    labels: Labels. 1D tensor of shape [batch_size].

  Returns:
    Tensor of shape [] containing the total loss for a batch of data.
  """
  # Build inference Graph.
  logits = cifar100.inference(images)

  # Build the portion of the Graph calculating the losses. Note that we will
  # assemble the total_loss using a custom function below.
  _ = cifar100.loss(logits, labels)

  # Assemble all of the losses for the current tower only.
  losses = tf.get_collection('losses', scope)

  # Calculate the total loss for the current tower.
  total_loss = tf.add_n(losses, name='total_loss')

  # Attach a scalar summary to all individual losses and the total loss; do
  # the same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on TensorBoard.
    loss_name = re.sub('%s_[0-9]*/' % cifar100.TOWER_NAME, '', l.op.name)
    tf.summary.scalar(loss_name, l)

  return total_loss
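# The sketch below is not part of the original code; it is a minimal example of
# how tower_loss() is typically driven from a multi-GPU training loop. The
# FLAGS.num_gpus flag, the optimizer `opt`, and the build_towers name are
# assumptions; the returned per-tower gradients would still need to be averaged
# and applied by the caller.
def build_towers(opt, images, labels):
  """Hypothetical helper: builds one tower per GPU and collects gradients."""
  tower_grads = []
  with tf.variable_scope(tf.get_variable_scope()):
    for i in xrange(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % (cifar100.TOWER_NAME, i)) as scope:
          # Each tower computes its total loss under a unique name scope,
          # which is the `scope` argument expected by tower_loss().
          loss = tower_loss(scope, images, labels)
          # Reuse variables so all towers share the same weights.
          tf.get_variable_scope().reuse_variables()
          # Keep per-tower gradients for later averaging.
          tower_grads.append(opt.compute_gradients(loss))
  return tower_grads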
def train():
  """Train CIFAR-100 for a number of steps (plain tf.Session variant)."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-100.
    images, labels = cifar100.distorted_inputs()

    # Build the inference graph, the loss, and the training op.
    logits = cifar100.inference(images)
    loss = cifar100.loss(logits, labels)
    train_op = cifar100.train(loss, global_step)

    # Create a saver and an initializer for all variables.
    saver = tf.train.Saver(tf.all_variables())
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners that feed the input pipeline.
    tf.train.start_queue_runners(sess=sess)

    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
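# A sketch of the imports and flags the snippets above and below assume. The
# exact defaults are assumptions, and some of these flags (e.g. batch_size,
# data_dir) may live in the cifar100 module in the original project rather
# than in the training script.
from __future__ import print_function

from datetime import datetime
import os
import re
import time

import numpy as np
from six.moves import xrange
import tensorflow as tf

import cifar100

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar100_train',
                           """Directory for event logs and checkpoints.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_integer('batch_size', 128, """Images per batch.""")
tf.app.flags.DEFINE_integer('log_frequency', 10,
                            """How often to log results to the console.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")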
def train(): """Train CIFAR-100 for a number of steps.""" with tf.Graph().as_default(): global_step = tf.contrib.framework.get_or_create_global_step() # Get images and labels for CIFAR-100. # Force input pipeline to CPU:0 to avoid operations sometimes ending up on # GPU and resulting in a slow down. with tf.device('/cpu:0'): images, labels = cifar100.distorted_inputs() # Build a Graph that computes the logits predictions from the # inference model. logits = cifar100.inference(images) # Calculate loss. loss = cifar100.loss(logits, labels) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. train_op = cifar100.train(loss, global_step) class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results self._last_loss = loss_value examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ( '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) def last_loss(self): return self._last_loss loghook = _LoggerHook() with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[ tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), loghook ], config=tf.ConfigProto(log_device_placement=FLAGS. log_device_placement)) as mon_sess: t1 = time.time() while not mon_sess.should_stop(): mon_sess.run(train_op) t2 = time.time() print('spent %f seconds to train %d step' % (t2 - t1, FLAGS.max_steps)) print('spent %f seconds to train %d step' % (t2 - t1, FLAGS.max_steps)) print('last loss value: %.2f ' % loghook.last_loss())
def train(): """训练cifar100""" with tf.Graph().as_default(): global_step = tf.train.get_or_create_global_step() with tf.device('/cpu:0'): images, labels = cifar100.destorted_inputs() # 建立模型,并获取得到的结果logits,用于与labels求交叉熵 logits = cifar100.inference(images) # 计算损失 loss = cifar100.loss(logits, labels) train_op = cifar100.train(loss, global_step) class _LoggerHook(tf.train.SessionRunHook): """打印损失和运行状态""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) config.gpu_options.allocator_type = 'BFC' # 使用BFC算法 config.gpu_options.per_process_gpu_memory_fraction = 0.5 # 程序最多只能占用指定gpu50%的显存 config.gpu_options.allow_growth = True # 程序按需申请内存 # MonitoredTrainingSession是一个方便的tensorflow会话初始化/恢复器, # 也可用于分布式训练 with tf.train.MonitoredTrainingSession( # 加载保存的训练状态的目录,如为空则设为保存目录 checkpoint_dir=FLAGS.train_dir, # 保存间隔 save_checkpoint_secs=None, save_checkpoint_steps=10000, # 可选的SessionRunHook对象列表 # StopAtStepHook表示停止步数 # NanTensorHook表示当loss为None时返回异常并停止训练 hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=config) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op)
def train(): """Train CIFAR-10 for a number of steps.""" #lanhin #Construct the cluster and start the server ps_spec = FLAGS.ps_hosts.split(",") worker_spec = FLAGS.worker_hosts.split(",") # Get the number of workers. num_workers = len(worker_spec) cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec}) server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) if FLAGS.job_name == "ps": server.join() # only worker will do train() is_chief = False if FLAGS.task_index == 0: is_chief = True #lanhin end #with tf.Graph().as_default(): # Use comment to choose which way of tf.device() you want to use #with tf.Graph().as_default(), tf.device(tf.train.replica_device_setter( # worker_device="/job:worker/task:%d" % FLAGS.task_index, # cluster=cluster)): with tf.device("job:worker/task:%d" % FLAGS.task_index): global_step = tf.contrib.framework.get_or_create_global_step() # Get images and labels for CIFAR-10. # Force input pipeline to CPU:0 to avoid operations sometimes ending up on # GPU and resulting in a slow down. #with tf.device('/cpu:0'): #images, labels = cifar10.distorted_inputs() (x_train, y_train_orl), (x_test, y_test_orl) = dset.cifar100.load_data( label_mode='fine') x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train, x_test = normalize(x_train, x_test) y_train_orl = y_train_orl.astype('int32') y_test_orl = y_test_orl.astype('int32') y_train_flt = y_train_orl.ravel() y_test_flt = y_test_orl.ravel() x = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, 32, 32, 3)) y = tf.placeholder(tf.int32, shape=(FLAGS.batch_size, )) # Build a Graph that computes the logits predictions from the # inference model. #logits, local_var_list = cifar10.inference(images) logits, local_var_list = cifar100.inference(x) # Calculate loss. #loss = cifar10.loss(logits, labels) loss = cifar100.loss(logits, y) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. 
    train_op = cifar100.train(loss, global_step)

    # The temp var part, for performance testing.
    tmp_var_list = []
    var_index = 0
    for var in local_var_list:
      var_index += 1
      tmp_var_list.append(
          tf.Variable(tf.zeros(var.shape), name="tmp_var" + str(var_index)))

    # The non-chief workers get their local var init_op here.
    if not is_chief:
      init_op = tf.global_variables_initializer()
    else:
      init_op = None

    # Start global variables region.
    global_var_list = []
    with tf.device("/job:ps/replica:0/task:0/cpu:0"):
      # Barrier var.
      finished = tf.get_variable("worker_finished", [], tf.int32,
                                 tf.zeros_initializer(tf.int32),
                                 trainable=False)
      with finished.graph.colocate_with(finished):
        finish_op = finished.assign_add(1, use_locking=True)

      var_index = 0
      for var in local_var_list:
        var_index += 1
        global_var_list.append(
            tf.Variable(tf.zeros(var.shape), name="glo_var" + str(var_index)))

    def assign_global_vars():  # Assign local vars' values to global vars.
      return [gvar.assign(lvar)
              for (gvar, lvar) in zip(global_var_list, local_var_list)]

    def assign_local_vars():  # Assign global vars' values to local vars.
      return [lvar.assign(gvar)
              for (gvar, lvar) in zip(global_var_list, local_var_list)]

    def assign_tmp_vars():  # Assign local vars' values to tmp vars.
      return [tvar.assign(lvar)
              for (tvar, lvar) in zip(tmp_var_list, local_var_list)]

    def assign_local_vars_from_tmp():  # Assign tmp vars' values to local vars.
      return [lvar.assign(tvar)
              for (tvar, lvar) in zip(tmp_var_list, local_var_list)]

    def update_before_train(alpha, w, global_w):
      varib = alpha * (w - global_w)
      gvar_op = global_w.assign(global_w + varib)
      return gvar_op, varib

    def update_after_train(w, vab):
      return w.assign(w - vab)

    assign_list_local = assign_local_vars()
    assign_list_global = assign_global_vars()
    assign_list_loc2tmp = assign_tmp_vars()
    assign_list_tmp2loc = assign_local_vars_from_tmp()

    before_op_tuple_list = []
    after_op_tuple_list = []
    vbholder_list = []
    for (gvar, lvar) in zip(global_var_list, local_var_list):
      before_op_tuple_list.append(update_before_train(alpha, lvar, gvar))
    for var in local_var_list:
      vbholder_list.append(tf.placeholder("float", var.shape))
      after_op_tuple_list.append(
          (update_after_train(var, vbholder_list[-1]), vbholder_list[-1]))

    # The chief worker gets the global var init op here.
    if is_chief:
      init_op = tf.global_variables_initializer()
    # Global variables region end.

    # lanhin start
    sv = tf.train.Supervisor(
        is_chief=True,  #is_chief,
        logdir=FLAGS.train_dir,
        init_op=init_op,
        #local_init_op=loc_init_op,
        recovery_wait_secs=1)  #global_step=global_step)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement,
        device_filters=["/job:ps",
                        "/job:worker/task:%d" % FLAGS.task_index])

    # The chief worker (task_index == 0) session will prepare the session,
    # while the remaining workers will wait for the preparation to complete.
    if is_chief:
      print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
      print("Worker %d: Waiting for session to be initialized..."
            % FLAGS.task_index)

    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

    if is_chief:
      sess.run(assign_list_global)
      barrier_finished = sess.run(finish_op)
      print("barrier_finished:", barrier_finished)
    else:
      barrier_finished = sess.run(finish_op)
      print("barrier_finished:", barrier_finished)
      while barrier_finished < num_workers:
        time.sleep(1)
        barrier_finished = sess.run(finished)
      sess.run(assign_list_local)

    print("Worker %d: Session initialization complete."
          % FLAGS.task_index)
    # lanhin end

    #sess = tf.Session()
    #sess.run(init_op)
    #tf.train.start_queue_runners(sess)

    f = open('tl_dist.json', 'w')
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    run_metadata = tf.RunMetadata()

    time_begin = time.time()
    # while not mon_sess.should_stop():
    #   mon_sess.run(train_op)
    for step in range(FLAGS.max_steps):
      offset = (step * FLAGS.batch_size) % (EPOCH_SIZE - FLAGS.batch_size)
      x_data = x_train[offset:(offset + FLAGS.batch_size), ...]
      y_data_flt = y_train_flt[offset:(offset + FLAGS.batch_size)]

      if step % FLAGS.log_frequency == 0:
        time_step = time.time()
        steps_time = time_step - time_begin
        print("step:", step, " steps time:", steps_time, end=' ')
        # Save local weights to the temp copies, evaluate with the global
        # weights, then restore the local weights.
        sess.run(assign_list_loc2tmp)
        sess.run(assign_list_local)
        predt(sess, x_test, y_test_flt, logits, x, y)
        sess.run(assign_list_tmp2loc)
        time_begin = time.time()

      if step % FLAGS.tau == 0 and step > 0:
        # Update global weights.
        thevarib_list = []
        for i in range(0, len(before_op_tuple_list)):
          (gvar_op, varib) = before_op_tuple_list[i]
          _, thevarib = sess.run([gvar_op, varib])
          thevarib_list.append(thevarib)

        sess.run(train_op, feed_dict={x: x_data, y: y_data_flt})

        for i in range(0, len(after_op_tuple_list)):
          (lvar_op, thevaribHolder) = after_op_tuple_list[i]
          sess.run(lvar_op, feed_dict={thevaribHolder: thevarib_list[i]})
      else:
        sess.run(train_op, feed_dict={
            x: x_data,
            y: y_data_flt
        })  #, options=run_options, run_metadata=run_metadata)
        #tl = timeline.Timeline(run_metadata.step_stats)
        #ctf = tl.generate_chrome_trace_format()
        #f.write(ctf)

    time_end = time.time()
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)
    f.close()

    sess.run(assign_list_local)
    predt(sess, x_test, y_test_flt, logits, x, y)
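# The distributed snippet above calls normalize() and predt() without defining
# them, and it assumes EPOCH_SIZE and a FLAGS.tau (global-sync period) flag are
# defined elsewhere. The helpers below are hypothetical reconstructions,
# sketched only from how they are called: per-channel standardization and
# batched top-1 test accuracy.
def normalize(x_train, x_test):
  # Standardize both splits with per-channel statistics from the training set.
  mean = np.mean(x_train, axis=(0, 1, 2))
  std = np.std(x_train, axis=(0, 1, 2))
  return (x_train - mean) / (std + 1e-7), (x_test - mean) / (std + 1e-7)

def predt(sess, x_test, y_test_flt, logits, x, y):
  # Evaluate top-1 accuracy over the test set in batch_size-sized chunks,
  # since the placeholders have a fixed batch dimension.
  correct = 0
  total = 0
  for offset in range(0, x_test.shape[0] - FLAGS.batch_size + 1,
                      FLAGS.batch_size):
    x_batch = x_test[offset:offset + FLAGS.batch_size]
    y_batch = y_test_flt[offset:offset + FLAGS.batch_size]
    preds = sess.run(logits, feed_dict={x: x_batch, y: y_batch})
    correct += np.sum(np.argmax(preds, axis=1) == y_batch)
    total += FLAGS.batch_size
  print("test accuracy: %.4f" % (float(correct) / total))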
def train(): """Train CIFAR-100 for a number of steps.""" output = open('output_data/output_' + str(time.time()) + '.txt', 'w') with tf.Graph().as_default(): global_step = tf.train.get_or_create_global_step() # Get images and labels for CIFAR-100. # Force input pipeline to CPU:0 to avoid operations sometimes ending up on # GPU and resulting in a slow down. with tf.device('/cpu:0'): images, labels = cifar100.distorted_inputs() # Build a Graph that computes the logits predictions from the # inference model. logitsA,logitsB = cifar100.inference(images) # Calculate loss. lossA = cifar100.loss(logitsA, labels) lossB = cifar100.loss(logitsB, labels) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. train_opA = cifar100.train(lossA, global_step) train_opB = cifar100.train(lossB, global_step) class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(lossA) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) print((str(self._step) + '\t' + str(loss_value) + '\n'), file=output) with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(lossA), tf.train.NanTensorHook(lossB), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: file_writer = tf.summary.FileWriter('tb-logs/', mon_sess.graph) while not mon_sess.should_stop(): print("stepA") mon_sess.run(train_opA) print("stepB") mon_sess.run(train_opB) output.close()
def train():
  """Train CIFAR-100 with synchronous distributed replicas."""
  print('FLAGS.data_dir: %s' % FLAGS.data_dir)
  ps_hosts = FLAGS.ps_hosts.split(",")
  worker_hosts = FLAGS.worker_hosts.split(",")

  # Create a cluster from the parameter server and worker hosts.
  cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
  server = tf.train.Server(cluster,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_index)
  if FLAGS.job_name == 'ps':
    server.join()

  is_chief = (FLAGS.task_index == 0)

  with tf.device(
      tf.train.replica_device_setter(
          worker_device="/job:worker/task:%d" % FLAGS.task_index,
          ps_device="/job:ps/task:0",
          cluster=cluster)):
    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0),
                                  trainable=False)

    # Get images and labels for CIFAR-100.
    images, labels = cifar100.distorted_inputs()

    num_workers = len(worker_hosts)
    num_replicas_to_aggregate = num_workers

    logits = cifar100.inference(images)

    # Calculate loss.
    loss = cifar100.loss(logits, labels)

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (cifar100.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * cifar100.NUM_EPOCHS_PER_DECAY)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(cifar100.INITIAL_LEARNING_RATE,
                                    global_step,
                                    decay_steps,
                                    cifar100.LEARNING_RATE_DECAY_FACTOR,
                                    staircase=True)

    # Retain the summaries at the chief and track the learning rate.
    if is_chief:
      summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
      summaries.append(tf.summary.scalar('learning_rate', lr))

    # Create an optimizer that performs gradient descent and wrap it for
    # synchronous replica training.
    opt = tf.train.GradientDescentOptimizer(lr)
    opt = tf.train.SyncReplicasOptimizer(
        opt,
        replicas_to_aggregate=num_replicas_to_aggregate,
        total_num_replicas=num_workers,
        #use_locking=True)
        use_locking=False)

    # Calculate the gradients for the batch.
    grads = opt.compute_gradients(loss)

    # Add histograms for gradients at the chief worker.
    if is_chief:
      for grad, var in grads:
        if grad is not None:
          summaries.append(
              tf.summary.histogram(var.op.name + '/gradients', grad))

    # Apply gradients to the variables.
    train_op = opt.apply_gradients(grads, global_step=global_step)

    # Add histograms for trainable variables.
    if is_chief:
      for var in tf.trainable_variables():
        summaries.append(tf.summary.histogram(var.op.name, var))

    #variable_averages = tf.train.ExponentialMovingAverage(
    #    cifar100.MOVING_AVERAGE_DECAY, global_step)
    #variables_averages_op = variable_averages.apply(tf.trainable_variables())
    #train_op = tf.group(train_op, variables_averages_op)

    if is_chief:
      # Build the summary operation at the chief worker.
      summary_op = tf.summary.merge(summaries)
      chief_queue_runner = opt.get_chief_queue_runner()
      init_token_op = opt.get_init_tokens_op()

    # Build an initialization operation to run below.
    init_op = tf.global_variables_initializer()

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    sv = tf.train.Supervisor(is_chief=is_chief,
                             global_step=global_step,
                             init_op=init_op)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement)

    # Start running operations on the Graph. allow_soft_placement must be set
    # to True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    with sv.prepare_or_wait_for_session(server.target,
                                        config=sess_config) as sess:
      # Start the standard queue runners; at the chief worker also start the
      # sync queue runner and run the init token op.
      queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
      sv.start_queue_runners(sess, queue_runners)
      if is_chief:
        sv.start_queue_runners(sess, [chief_queue_runner])
        sess.run(init_token_op)

      # Open the summary writer.
      summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

      t1 = time.time()
      for step in xrange(FLAGS.max_steps):
        start_time = time.time()
        _, loss_value = sess.run([train_op, loss])
        duration = time.time() - start_time

        assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

        if step % 10 == 0:
          num_examples_per_step = FLAGS.batch_size * num_workers
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = duration / num_workers

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.now(), step, loss_value,
                              examples_per_sec, sec_per_batch))

        if step % 100 == 0:
          if is_chief:
            summary_str = sess.run(summary_op)
            summary_writer.add_summary(summary_str, step)

        # Save the model checkpoint periodically.
        if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
          if is_chief:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

      t2 = time.time()
      print('spent %f seconds to train %d steps' % (t2 - t1, FLAGS.max_steps))
      logger.info('spent %f seconds to train %d steps'
                  % (t2 - t1, FLAGS.max_steps))
      logger.info('last loss value: %.2f ' % loss_value)
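# A sketch of the cluster-related flags that the two distributed variants above
# read via FLAGS; the flag names match how they are used in the code, but the
# defaults here are assumptions. They would be combined with an entry point
# like the main()/tf.app.run() sketch shown earlier.
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar100_data',
                           """Path to the CIFAR-100 data directory.""")
tf.app.flags.DEFINE_string('ps_hosts', 'localhost:2222',
                           """Comma-separated list of parameter server hosts.""")
tf.app.flags.DEFINE_string('worker_hosts', 'localhost:2223,localhost:2224',
                           """Comma-separated list of worker hosts.""")
tf.app.flags.DEFINE_string('job_name', 'worker', """Either 'ps' or 'worker'.""")
tf.app.flags.DEFINE_integer('task_index', 0,
                            """Index of the task within its job.""")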