def train():
    """Train Inception on a dataset for a number of steps.

    Builds the cluster from ``FLAGS.ps_hosts`` / ``FLAGS.worker_hosts`` and
    starts a ``tf.train.Server`` for this task.

    * ``ps`` tasks: task 0 additionally starts the ``BatchSizeManager`` RPC
      server on the host of the first PS; every PS task then blocks in
      ``server.join()`` and never runs the worker code.
    * ``worker`` tasks: build a slim ``inception_v3`` graph whose variables
      are placed on the PS tasks through ``replica_device_setter`` (inside a
      ``'root'`` variable scope with a fixed-size partitioner), train it with
      RMSProp, and log timing statistics every step until the global step
      exceeds ``FLAGS.max_steps`` or the Supervisor requests a stop.

    The per-step batch size is fed through the ``batch_size`` placeholder and
    every gradient is multiplied by ``batch_size / FLAGS.batch_size``.
    Worker 0 acts as the chief.
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)
    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server(
        {'ps': ps_hosts, 'worker': worker_hosts},
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_id,
        protocol=FLAGS.protocol)

    batchSizeManager = BatchSizeManager(FLAGS.batch_size, len(worker_hosts))
    # The batch-size RPC service lives on the host part of the first PS address.
    rpc_host = ps_hosts[0].split(':')[0]

    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            rpcServer = batchSizeManager.create_rpc_server(rpc_host)
            rpcServer.serve()
        server.join()
        # A PS task must never fall through into the worker code below.
        return

    dataset = ImagenetData(subset=FLAGS.subset)
    # Keep a reference to the RPC client; it is only used by the (currently
    # disabled) dynamic batch-size logic.
    rpcClient = batchSizeManager.create_rpc_client(rpc_host)
    assert dataset.data_files()

    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0 and not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])
    if FLAGS.num_replicas_to_aggregate == -1:
        num_replicas_to_aggregate = num_workers
    else:
        num_replicas_to_aggregate = FLAGS.num_replicas_to_aggregate

    # Both should be greater than 0 in a distributed training.
    assert num_workers > 0 and num_parameter_servers > 0, (
        ' num_workers and '
        'num_parameter_servers'
        ' must be > 0.')

    # Choose worker 0 as the chief. Any worker could be the chief, but there
    # must be exactly one.
    is_chief = (FLAGS.task_id == 0)

    tf.logging.info('cccc-num_parameter_servers:' + str(num_parameter_servers))
    partitioner = tf.fixed_size_partitioner(num_parameter_servers, 0)
    device_setter = tf.train.replica_device_setter(
        ps_tasks=num_parameter_servers)

    # Ops run on this worker by default; variables are sent to the PS tasks by
    # the replica device setter.
    with tf.device('/job:worker/task:%d' % FLAGS.task_id), \
            tf.variable_scope('root', partitioner=partitioner), \
            tf.device(device_setter):
        # Counts applied updates; incremented by apply_gradients below.
        global_step = tf.Variable(0, trainable=False)

        # Batch size for the current step, fed at sess.run time.
        batch_size = tf.placeholder(dtype=tf.int32, shape=(), name='batch_size')

        # Learning-rate schedule (based on the nominal FLAGS.batch_size).
        num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                 FLAGS.batch_size)
        # Decay steps need to be divided by the number of replicas to aggregate.
        decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay /
                          num_replicas_to_aggregate)
        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                        global_step,
                                        decay_steps,
                                        FLAGS.learning_rate_decay_factor,
                                        staircase=True)

        images, labels = image_processing.distorted_inputs(
            dataset, batch_size,
            num_preprocess_threads=FLAGS.num_preprocess_threads)
        print(images.get_shape())
        print(labels.get_shape())

        # NOTE(review): main() uses dataset.num_classes() + 1, treating label 0
        # as a reserved background class. If the labels fed here are 1-based,
        # the largest label equals num_classes and lies outside the one-hot
        # range used below. Confirm the label range of this dataset.
        num_classes = dataset.num_classes()
        print(num_classes)

        network_fn = nets_factory.get_network_fn(
            'inception_v3', num_classes=num_classes)
        logits, _ = network_fn(images)
        print(logits.get_shape())

        # One-hot depth follows num_classes so the labels always match the
        # width of the logits.
        labels = tf.one_hot(labels, num_classes, 1, 0)
        cross_entropy = tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels)
        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        # Explicit L2 weight decay over every trainable variable.
        total_loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
            [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

        # Moving averages of the regularization losses and the total loss.
        loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
        loss_averages_op = loss_averages.apply(losses + [total_loss])
        with tf.control_dependencies([loss_averages_op]):
            opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)
            grads0 = opt.compute_gradients(total_loss)
            # Cast before dividing: `int32_tensor / int` is floor division
            # under Python-2 semantics (no `from __future__ import division`),
            # which would zero every gradient when the fed batch is smaller
            # than FLAGS.batch_size.
            grad_scale = (tf.cast(batch_size, tf.float32) /
                          float(FLAGS.batch_size))
            grads = [
                (grad if grad is None else tf.scalar_mul(grad_scale, grad), var)
                for grad, var in grads0
            ]
            total_loss = tf.identity(total_loss)

        # Maintain moving averages of all trainable variables.
        exp_moving_averager = tf.train.ExponentialMovingAverage(
            MOVING_AVERAGE_DECAY, global_step)
        variables_averages_op = exp_moving_averager.apply(
            tf.trainable_variables())
        apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
        with tf.control_dependencies(
                [apply_gradients_op, variables_averages_op]):
            train_op = tf.identity(total_loss, name='train_op')

        # NOTE(review): this saver is not handed to the Supervisor (saver=None
        # below), so no checkpoints are written and save_model_secs has no
        # effect. Confirm that disabling checkpointing is intended.
        saver = tf.train.Saver()

        init_op = tf.global_variables_initializer()

        # summary_op=None: summaries are not run by the Supervisor.
        sv = tf.train.Supervisor(is_chief=is_chief,
                                 logdir=FLAGS.train_dir,
                                 init_op=init_op,
                                 summary_op=None,
                                 global_step=global_step,
                                 recovery_wait_secs=1,
                                 saver=None,
                                 save_model_secs=FLAGS.save_interval_secs)
        tf.logging.info('%s Supervisor' % datetime.now())

        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=FLAGS.log_device_placement)

        # Get a session.
        sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

        # Start the queue runners.
        queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        sv.start_queue_runners(sess, queue_runners)
        tf.logging.info('Started %d queues for processing input data.',
                        len(queue_runners))

        # The fed batch size is currently fixed; the dynamic batch-size
        # experiments (BatchSizeManager / rpcClient) are disabled.
        batch_size_num = 32
        # FULL_TRACE is requested every step although run_metadata is never
        # read; it likely inflates the reported step times.
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        time0 = time.time()
        while not sv.should_stop():
            try:
                start_time = time.time()
                run_metadata = tf.RunMetadata()
                loss_value, step = sess.run(
                    [train_op, global_step],
                    feed_dict={batch_size: batch_size_num},
                    options=run_options,
                    run_metadata=run_metadata)
                run_done = time.time()
                if step > FLAGS.max_steps:
                    break
                duration = time.time() - start_time
                bs_start = time.time()
                # Throughput is based on the batch actually fed this step.
                examples_per_sec = batch_size_num / float(duration)
                bs_done = time.time()
                tf.logging.info("time statistics" +
                                " - train_time: " + str(run_done - start_time) +
                                " - get_batch_time: " + str(bs_start - run_done) +
                                " - get_bs_time: " + str(bs_done - bs_start) +
                                " - accum_time: " + str(bs_done - time0) +
                                " - batch_size: " + str(batch_size_num))
                format_str = ('Worker %d: %s: step %d, loss = %.2f'
                              '(%.1f examples/sec; %.3f sec/batch)')
                tf.logging.info(format_str %
                                (FLAGS.task_id, datetime.now(), step,
                                 loss_value, examples_per_sec, duration))
            except Exception:
                if is_chief:
                    tf.logging.info('Chief got exception while running!')
                raise

        # Stop the supervisor. This also waits for service threads to finish.
        sv.stop()
def main(argv=None):
    """Run SSP-style distributed Inception training for this task.

    Starts a ``tf.train.Server`` for this task.  ``ps`` task 0 additionally
    serves the ``SspManager`` RPC service on the host of the first PS; every
    ``ps`` task then blocks in ``server.join()`` and never runs the worker
    code.  Worker tasks:

    * build the Inception graph, with variables placed on the PS tasks by
      ``VariableDeviceChooser``;
    * train with RMSProp for ``FLAGS.max_steps`` local steps;
    * call ``rpcClient.check_staleness(task_id, step)`` after every step.

    Worker 0 is the chief.  The chief alone keeps moving averages of the
    losses and creates ``train_dir``.

    Args:
      argv: Unused (presumably present so the function can be passed to
        ``tf.app.run``).
    """
    ps_hosts = FLAGS.ps_hosts.split(',')
    worker_hosts = FLAGS.worker_hosts.split(',')
    tf.logging.info('PS hosts are: %s' % ps_hosts)
    tf.logging.info('Worker hosts are: %s' % worker_hosts)
    cluster_spec = tf.train.ClusterSpec({
        'ps': ps_hosts,
        'worker': worker_hosts
    })
    server = tf.train.Server(
        {'ps': ps_hosts, 'worker': worker_hosts},
        job_name=FLAGS.job_name,
        task_index=FLAGS.task_id,
        protocol=FLAGS.protocol)

    # NOTE(review): the second argument (5) is presumably the SSP staleness
    # bound -- confirm against SspManager.
    sspManager = SspManager(len(worker_hosts), 5)
    # The SSP RPC service lives on the host part of the first PS address.
    rpc_host = ps_hosts[0].split(':')[0]
    if FLAGS.job_name == 'ps':
        if FLAGS.task_id == 0:
            rpcServer = sspManager.create_rpc_server(rpc_host)
            rpcServer.serve()
        server.join()
        # A PS task must never fall through into the worker code below.
        return

    # Presumably gives the PS-side RPC server time to start before connecting.
    time.sleep(5)
    rpcClient = sspManager.create_rpc_client(rpc_host)
    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()

    is_chief = (FLAGS.task_id == 0)
    if is_chief and not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)

    num_workers = len(cluster_spec.as_dict()['worker'])
    num_parameter_servers = len(cluster_spec.as_dict()['ps'])

    with tf.device('/job:worker/task:%d' % FLAGS.task_id):
        with slim.scopes.arg_scope(
                [slim.variables.variable, slim.variables.global_step],
                device=slim.variables.VariableDeviceChooser(
                    num_parameter_servers)):
            # ---- Prepare input ----
            global_step = slim.variables.global_step()
            # Batch size for the current step, fed at sess.run time.
            batch_size = tf.placeholder(dtype=tf.int32, shape=(),
                                        name='batch_size')
            images, labels = image_processing.distorted_inputs(
                dataset, batch_size,
                num_preprocess_threads=FLAGS.num_preprocess_threads)
            # The "+ 1" presumably reserves label 0 for an unused background
            # class.
            num_classes = dataset.num_classes() + 1

            # ---- Inference ----
            logits = inception.inference(images, num_classes, for_training=True)

            # ---- Loss ----
            inception.loss(logits, labels, batch_size)
            losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
            losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            total_loss = tf.add_n(losses, name='total_loss')
            if is_chief:
                # Only the chief maintains moving averages of the losses.
                loss_averages = tf.train.ExponentialMovingAverage(0.9,
                                                                  name='avg')
                loss_averages_op = loss_averages.apply(losses + [total_loss])
                with tf.control_dependencies([loss_averages_op]):
                    total_loss = tf.identity(total_loss)

            # ---- Optimizer ----
            # NOTE(review): unlike train(), no moving average of the trainable
            # variables is applied on this path.
            num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                                     FLAGS.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              FLAGS.num_epochs_per_decay / num_workers)
            lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                            global_step,
                                            decay_steps,
                                            FLAGS.learning_rate_decay_factor,
                                            staircase=True)
            opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                            momentum=RMSPROP_MOMENTUM,
                                            epsilon=RMSPROP_EPSILON)

            # ---- Train operation ----
            batchnorm_updates = tf.get_collection(
                slim.ops.UPDATE_OPS_COLLECTION)
            assert batchnorm_updates, 'Batchnorm updates are missing'
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            # Make the loss (and hence training) depend on batchnorm updates.
            with tf.control_dependencies([batchnorm_updates_op]):
                total_loss = tf.identity(total_loss)
            naive_grads = opt.compute_gradients(total_loss)
            # Cast before dividing: `int32_tensor / int` is floor division
            # under Python-2 semantics, which would zero every gradient when
            # the fed batch is smaller than FLAGS.batch_size.
            grad_scale = (tf.cast(batch_size, tf.float32) /
                          float(FLAGS.batch_size))
            grads = [
                (grad if grad is None else tf.scalar_mul(grad_scale, grad), var)
                for grad, var in naive_grads
            ]
            apply_gradients_op = opt.apply_gradients(grads,
                                                     global_step=global_step)
            with tf.control_dependencies([apply_gradients_op]):
                train_op = tf.identity(total_loss, name='train_op')

            # ---- Supervisor and session ----
            saver = tf.train.Saver()
            init_op = tf.global_variables_initializer()
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=FLAGS.train_dir,
                                     init_op=init_op,
                                     summary_op=None,
                                     global_step=global_step,
                                     recovery_wait_secs=1,
                                     saver=saver,
                                     save_model_secs=FLAGS.save_interval_secs)
            tf.logging.info('%s Supervisor' % datetime.now())
            sess_config = tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=FLAGS.log_device_placement)
            sess = sv.prepare_or_wait_for_session(server.target,
                                                  config=sess_config)
            queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)

            # ---- Start training ----
            sv.start_queue_runners(sess, queue_runners)
            tf.logging.info('Started %d queues for processing input data.',
                            len(queue_runners))

            batch_size_num = FLAGS.batch_size
            # FULL_TRACE is requested every step although run_metadata is
            # never read; it likely inflates the reported step times.
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            for step in range(FLAGS.max_steps):
                start_time = time.time()
                run_metadata = tf.RunMetadata()
                loss_value, gs = sess.run(
                    [train_op, global_step],
                    feed_dict={batch_size: batch_size_num},
                    options=run_options,
                    run_metadata=run_metadata)
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'
                duration = time.time() - start_time
                examples_per_sec = batch_size_num / float(duration)
                sec_per_batch = float(duration)
                format_str = (
                    "time: " + str(time.time()) +
                    '; %s: step %d (gs %d), loss= %.2f (%.1f samples/s; %.3f s/batch)'
                )
                tf.logging.info(format_str %
                                (datetime.now(), step, gs, loss_value,
                                 examples_per_sec, sec_per_batch))
                # Report this worker's local step to the SSP service.
                rpcClient.check_staleness(FLAGS.task_id, step)

            # Stop the supervisor; this also waits for its service threads
            # (queue runners, checkpoint thread) to finish.
            sv.stop()
"""Number of iterations to run.""")
tf.app.flags.DEFINE_string('model_file', 'model/DCNet_',
                           """Directory to save model""")

# Boolean switch fed at run time. It selects the train vs. validation input
# pipeline through tf.cond below and is also passed to cnn.build.
is_training = tf.placeholder("bool")

# Training pipeline uses distorted_inputs; validation uses inputs.
train_set = ImagenetData(subset='train')
tr_images, tr_labels = alex2012_image_processing.distorted_inputs(train_set)
val_set = ImagenetData(subset='validation')
val_images, val_labels = alex2012_image_processing.inputs(val_set)
# NOTE(review): both pipelines are created outside the tf.cond branch
# functions. Per the tf.cond documentation, such tensors are executed
# regardless of which branch is taken. So every step that evaluates `images`
# likely dequeues from BOTH the training and the validation queue. Confirm
# this is acceptable.
images, labels = tf.cond(is_training, lambda: [tr_images, tr_labels],
                         lambda: [val_images, val_labels])

cnn = VGG()
cnn.build(images, train_set.num_classes(), is_training)

# Total objective = fit loss ('c_entropy' mode of loss2 -- presumably
# cross-entropy; confirm in loss2) + orthogonality terms collected under
# 'orth_constraint' + regularization losses.
fit_loss = loss2(cnn.score, labels, train_set.num_classes(), 'c_entropy')
# NOTE(review): tf.add_n raises on an empty list. These two lines assume
# the graph contains at least one regularization loss and at least one
# 'orth_constraint' tensor.
reg_loss = tf.add_n(tf.losses.get_regularization_losses())
orth_loss = tf.add_n(tf.get_collection('orth_constraint'))
loss_op = fit_loss + orth_loss + reg_loss

# Float placeholder fed at run time (presumably the learning rate).
lr_ = tf.placeholder("float")

# Trainable variables whose name contains '/filter', excluding any whose
# name contains 'score' or 'shortcut'.
weight_list = [v for v in tf.trainable_variables() if ('/filter' in v.name and 'score' not in v.name and 'shortcut' not in v.name)]

# One op that overwrites each selected filter with cnn.sphere_dict[v.name].
# That value is presumably a re-projected version of the weight computed
# inside VGG -- confirm.
assign_op_list = []
for v in weight_list:
    assign_op_list.append(tf.assign(v, cnn.sphere_dict[v.name]))
assign_op = tf.group(*assign_op_list)