def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  tf.summary.image('image', features, max_outputs=6)
  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _NUM_CLASSES, params['data_format'])
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss; the raw L2 sum must be scaled by the decay
  # coefficient, as in the sibling snippets below.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      # Pass global_step so the piecewise schedule actually advances.
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  # tf.metrics.accuracy, not tf.train.accuracy (which does not exist).
  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
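# A minimal wiring sketch (not from the snippets above): how a model_fn with
# the (features, labels, mode, params) signature is typically handed to
# tf.estimator in TF 1.x. The toy input_fn, model_dir, and params values are
# illustrative assumptions.
import tensorflow as tf

def _toy_input_fn():
  # Random stand-in data with CIFAR-10 shapes: 32x32x3 images, 10 classes.
  images = tf.random_uniform([128, 32, 32, 3])
  labels = tf.one_hot(tf.random_uniform([128], maxval=10, dtype=tf.int32), 10)
  return images, labels

classifier = tf.estimator.Estimator(
    model_fn=cifar10_model_fn,
    model_dir='/tmp/cifar10_model',
    params={'resnet_size': 32, 'data_format': 'channels_last',
            'batch_size': 128})
classifier.train(input_fn=_toy_input_fn, steps=100)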
def cifar10_model_fn(features, labels, mode):
  """Model function for CIFAR-10."""
  network = resnet_model.cifar10_resnet_v2_generator(
      FLAGS.resnet_size, NUM_CLASSES)
  inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
def create_loss():
  """Creates loss tensor for resnet model."""
  images = tf.random_uniform((BATCH_SIZE, HEIGHT, WIDTH, DEPTH))
  labels = tf.random_uniform((BATCH_SIZE, NUM_CLASSES))
  # channels_last for CPU
  if USE_TINY:
    network = resnet_model.tiny_cifar10_resnet_v2_generator(
        RESNET_SIZE, NUM_CLASSES, data_format='channels_last')
  else:
    network = resnet_model.cifar10_resnet_v2_generator(
        RESNET_SIZE, NUM_CLASSES, data_format='channels_last')
  inputs = tf.reshape(images, [BATCH_SIZE, HEIGHT, WIDTH, DEPTH])
  logits = network(inputs, True)
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)
  l2_penalty = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
  loss = cross_entropy + _WEIGHT_DECAY * l2_penalty
  return loss
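# A small usage sketch for create_loss (assuming the module-level constants
# BATCH_SIZE, HEIGHT, WIDTH, DEPTH, NUM_CLASSES, RESNET_SIZE, USE_TINY, and
# _WEIGHT_DECAY are defined): build the loss once, then evaluate it a few
# times in a plain session, e.g. for benchmarking the forward pass.
import tensorflow as tf

loss = create_loss()
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    print('loss:', sess.run(loss))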
def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  os.environ['CUDA_VISIBLE_DEVICES'] = ''
  with tf.Graph().as_default() as g:
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)
    # images, labels = cifar10.distorted_inputs(128)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    with tf.variable_scope('root'):
      network = resnet_model.cifar10_resnet_v2_generator(
          FLAGS.resnet_size, _NUM_CLASSES)
      logits = network(images, True)

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    # variable_averages = tf.train.ExponentialMovingAverage(
    #     cifar10.MOVING_AVERAGE_DECAY)
    # variables_to_restore = variable_averages.variables_to_restore()
    # saver = tf.train.Saver(variables_to_restore)
    saver = tf.train.Saver()

    # Build the summary operation based on the TF collection of Summaries.
    # summary_op = tf.merge_all_summaries()
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)

    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      tf.logging.info('continue')
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
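# eval_once is called above but not defined in this collection. Below is a
# hedged reconstruction following the standard CIFAR-10 tutorial pattern
# (restore the latest checkpoint, count top-1 hits over the eval set, write a
# precision summary). FLAGS.checkpoint_dir, FLAGS.num_examples, and
# FLAGS.batch_size are assumed flag names, not taken from the snippets.
import math
import numpy as np
import tensorflow as tf

def eval_once(saver, summary_writer, top_k_op, summary_op):
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('No checkpoint file found')
      return
    saver.restore(sess, ckpt.model_checkpoint_path)
    # e.g. .../model.ckpt-1000 -> global step 1000
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0
      for _ in range(num_iter):
        true_count += np.sum(sess.run([top_k_op]))
      precision = true_count / (num_iter * FLAGS.batch_size)
      print('precision @ 1 = %.3f' % precision)

      summary = tf.Summary()
      summary.value.add(tag='Precision @ 1', simple_value=precision)
      summary_writer.add_summary(summary, int(global_step))
    finally:
      coord.request_stop()
      coord.join(threads, stop_grace_period_secs=10)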
def cifar10_model_fn(features, labels, mode):
  """Model function for CIFAR-10."""
  ## Temporary solution to run load only once.
  global load_done

  tf.summary.image('images', features, max_outputs=6)
  network = resnet.cifar10_resnet_v2_generator(FLAGS.resnet_size, NUM_CLASSES)
  inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  ## All the required modifications are here.
  ## Since this routine uses tf.estimator to implement ResNet, if we would
  ## like to dump or reload the data, the only method of communication is
  ## checkpoints. Therefore, each command is called separately.

  ## Load the data ##
  ## Executed only if the retrain flag is set.
  ## The load happens only once in the retrain cycle.
  RETRAIN = FLAGS.retrain
  if (load_done == 0) and RETRAIN:
    load_done = 1
    print("Loading pretrained weights")
    ## Load the modified/pre-trained weight values.
    data = np.load("weights_cifar10.npy").item()
    addr = np.load("addr_table.npy").item()
    ## Path to the most recent checkpoint file.
    model_path = model_dir_saved + '/' + FLAGS.ckpt_file
    with tf.Session() as sess:
      ## All variables should be initialized.
      sess.run(tf.global_variables_initializer())
      ## Define the Saver instance based on the checkpoint file.
      saver = tf.train.Saver(tf.global_variables())
      saver.restore(sess, model_path)
      ## Get all the variable names in the required format.
      model = get_weights(dump=False)
      get_sessions = model.keys()
      ## Go through every scope.
      for i in get_sessions:
        if 'global_step' not in i:
          with tf.variable_scope(i, reuse=True):
            ## Go through every variable in the scope, with reuse.
            for val in model[i]:
              print('Loading weight variable: ' + val + ' in Scope: ' + i)
              ## Assign the loaded value to the weights.
              var = tf.get_variable(val.split(":")[0], trainable=False)
              sess.run(var.assign(data[i][addr[val]]))
      ## Save the model in the retrain directory.
      saver.save(sess, FLAGS.model_dir + '/model.ckpt')
      print("data successfully loaded")

  ## Dump the data ##
  ## Set DUMP to False if you don't want to dump.
  ## After the dump, the program ends.
  DUMP = FLAGS.dump
  if (mode == tf.estimator.ModeKeys.TRAIN) and DUMP:
    with tf.Session() as sess:
      print("Dumping weights now")
      weights_cifar10 = dict()
      ## All variables should be initialized.
      sess.run(tf.global_variables_initializer())
      ## Define the Saver instance based on the checkpoint file.
      saver = tf.train.Saver(tf.global_variables())
      saver.restore(sess, FLAGS.model_dir + '/' + FLAGS.ckpt_file)
      ## Get all the variables.
      model = get_weights(dump=True)
      get_sessions = model.keys()
      ## Go through every scope.
      for i in get_sessions:
        if 'global_step' not in i:
          with tf.variable_scope(i, reuse=True):
            ## Go through every variable in the scope, with reuse.
            layer_data = []
            for val in model[i]:
              print('Dumping weight variable: ' + val + ' in Scope: ' + i)
              ## Append each variable value to a file.
              layer_data.append(sess.run(tf.get_variable(val.split(":")[0])))
            weights_cifar10[i] = layer_data
      ## Save them to a npy file.
      np.save("weights_cifar10.npy", weights_cifar10)
      ## Exit the program.
      sys.exit("Dump Finished, Exiting ...")

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
    values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
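# The tf.identity(..., name=...) calls in these model functions exist so the
# tensors can be fetched by name at training time. A hedged sketch of how
# they are commonly consumed with a LoggingTensorHook; the estimator and
# train_input_fn in the commented call are placeholders.
import tensorflow as tf

tensors_to_log = {
    'learning_rate': 'learning_rate',
    'cross_entropy': 'cross_entropy',
    'train_accuracy': 'train_accuracy',
}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=100)
# estimator.train(input_fn=train_input_fn, hooks=[logging_hook])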
def save_usepre_model():
  wordmodel_dir = "./tmp/cifar100_model_res%d_word_weightdecay%f/" % (
      FLAGS.resnet_size, _WEIGHT_DECAY)
  network = resnet_model.cifar10_resnet_v2_generator(
      FLAGS.resnet_size, _NUM_CLASSES, FLAGS.data_format, FLAGS.more_layer)
  x = tf.placeholder(tf.float32, [None, 32, 32, 3])
  labels = tf.placeholder(tf.float32, [None, _NUM_CLASSES])
  logits = network(x, True)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  # Scale the learning rate linearly with the batch size. When the batch size
  # is 128, the learning rate should be 0.1.
  initial_learning_rate = 0.1
  batches_per_epoch = _NUM_IMAGES['train'] / 128
  global_step = tf.train.get_or_create_global_step()

  # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
  boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
  values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
  learning_rate = tf.train.piecewise_constant(
      tf.cast(global_step, tf.int32), boundaries, values)

  # Create a tensor named learning_rate for logging purposes.
  tf.identity(learning_rate, name='learning_rate')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate, momentum=_MOMENTUM)

  # Batch norm requires update ops to be added as a dependency to train_op.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, global_step)

  ckpt = tf.train.get_checkpoint_state(wordmodel_dir)
  model_restore_path = ckpt.model_checkpoint_path
  all_var = tf.trainable_variables()
  restorelist = []
  for key in all_var:
    print(key)
    if "dense_1" not in key.name:  # and "batch_normalization" not in key.name:
      restorelist.append(key)

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(restorelist)
  saver.restore(sess, model_restore_path)
  saver = tf.train.Saver()
  save_path = saver.save(sess, FLAGS.model_dir + "/model.ckpt")
  sess.close()
  print("Model saved in file: %s" % save_path)
def cifar100_model_fn(features, labels, mode, params):
  """Model function for CIFAR-100."""
  tf.summary.image('images', features, max_outputs=6)
  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _NUM_CLASSES, params['data_format'],
      more_layer=FLAGS.more_layer)
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)
  print(logits)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  dic_tensor_cifar100 = tf.convert_to_tensor(np_cifar100)

  # Calculate loss: squared error against word vectors when FLAGS.wordvec is
  # set, otherwise softmax cross entropy; plus L2 regularization either way.
  if FLAGS.wordvec:
    cross_entropy = tf.reduce_mean(
        tf.reduce_sum(tf.squared_difference(labels, logits), 1))
  else:
    cross_entropy = tf.losses.softmax_cross_entropy(
        logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])
  tf.identity(loss, name='loss')
  tf.summary.scalar('loss', loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Decay the learning rate at the epoch boundaries given by the flags.
    # boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 180, 300]]
    # values = [initial_learning_rate * decay for decay in [0.2, 0.04, 0.02, 0.01]]
    boundaries = [
        int(batches_per_epoch * epoch) for epoch in FLAGS.lr_boundaries
    ]
    values = [initial_learning_rate * decay for decay in FLAGS.lr_values]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  summary_hook = tf.train.SummarySaverHook(
      save_steps=100, output_dir=FLAGS.log_dir,
      summary_op=tf.summary.merge_all())

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics,
      training_hooks=[summary_hook])
def train():
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  print('PS hosts are: %s' % ps_hosts)
  print('Worker hosts are: %s' % worker_hosts)

  configP = tf.ConfigProto()
  server = tf.train.Server(
      {'ps': ps_hosts, 'worker': worker_hosts},
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id,
      config=configP)
  if FLAGS.job_name == 'ps':
    server.join()

  is_chief = (FLAGS.task_id == 0)
  if is_chief:
    if tf.gfile.Exists(FLAGS.train_dir):
      tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

  device_setter = tf.train.replica_device_setter(ps_tasks=len(ps_hosts))
  with tf.device('/job:worker/task:%d' % FLAGS.task_id):
    with tf.device(device_setter):
      """Prepare Input"""
      global_step = tf.Variable(0, trainable=False)
      decay_steps = (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * NUM_EPOCHS_PER_DECAY
                     / FLAGS.batch_size)
      batch_size = tf.placeholder(dtype=tf.int32, shape=(), name='batch_size')
      with tf.device('/cpu:0'):
        images, labels = cifar10.distorted_inputs(batch_size)
      inputs = tf.reshape(images, [-1, _HEIGHT, _WIDTH, _DEPTH])

      """Inference"""
      with tf.variable_scope(
          'root',
          partitioner=tf.fixed_size_partitioner(len(ps_hosts), axis=0)):
        network = resnet_model.cifar10_resnet_v2_generator(
            FLAGS.resnet_size, _NUM_CLASSES)
        logits = network(inputs, True)

      labels = tf.cast(labels, tf.int64)
      correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
      correct_prediction = tf.cast(correct_prediction, tf.float32)
      accuracy_op = tf.reduce_mean(correct_prediction)

      """Loss"""
      labels = tf.one_hot(labels, 10, 1, 0)
      cross_entropy = tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels)
      loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
          [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

      """Define Optimization"""
      # Decay the learning rate exponentially based on the number of steps.
      lr = tf.train.exponential_decay(
          INITIAL_LEARNING_RATE * len(worker_hosts),
          global_step,
          decay_steps,
          LEARNING_RATE_DECAY_FACTOR,
          staircase=True)
      opt = tf.train.GradientDescentOptimizer(lr)

      # Track the moving averages of all trainable variables.
      exp_moving_averager = tf.train.ExponentialMovingAverage(
          MOVING_AVERAGE_DECAY, global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=len(worker_hosts),
          total_num_replicas=len(worker_hosts),
          variable_averages=exp_moving_averager,
          variables_to_average=variables_to_average)

      # Compute gradients with respect to the loss.
      grads = opt.compute_gradients(loss)
      apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
      with tf.control_dependencies([apply_gradients_op]):
        train_op = tf.identity(loss, name='train_op')

      """Synchronization Management"""
      if is_chief:
        chief_queue_runners = [opt.get_chief_queue_runner()]
        init_tokens_op = opt.get_init_tokens_op()

      saver = tf.train.Saver(max_to_keep=1)
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=FLAGS.train_dir,
          init_op=tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer()),
          summary_op=None,
          global_step=global_step,
          saver=saver,
          recovery_wait_secs=1,
          save_model_secs=60)
      tf.logging.info('%s Supervisor' % datetime.now())

      """Train CIFAR-10 for a number of steps."""
      sess_config = tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement)
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)

      queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
      sv.start_queue_runners(sess, queue_runners)
      if is_chief:
        sv.start_queue_runners(sess, chief_queue_runners)
        sess.run(init_tokens_op)

      batch_size_num = FLAGS.batch_size
      for step in range(init_global_step, FLAGS.max_steps):
        step_start_time = time.time()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        num_batches_per_epoch = (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
                                 / batch_size_num)
        decay_steps_num = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        _, loss_value, gs = sess.run(
            [train_op, loss, global_step],
            feed_dict={batch_size: batch_size_num},
            options=run_options,
            run_metadata=run_metadata)

        duration = time.time() - step_start_time
        num_examples_per_step = batch_size_num
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)
        format_str = (
            "time: " + str(time.time()) +
            '; %s: step %d (gs %d), loss= %.2f (%.1f samples/s; %.3f s/batch)')
        tf.logging.info(format_str % (datetime.now(), step, gs, loss_value,
                                      examples_per_sec, sec_per_batch))

        """Do evaluation on accuracy (this is not testset evaluation)"""
        if step % 200 == 0:
          accuracy = sess.run(accuracy_op, feed_dict={batch_size: 10000})
          tf.logging.info('evaluation: step - ' + str(step) +
                          '; accuracy: ' + str(accuracy))
def cifar10_model_fn(features, labels, mode):
  """Model function for CIFAR-10."""
  tf.summary.image('images', features, max_outputs=6)

  with tf.device('/gpu:3'):
    network = resnet_model.cifar10_resnet_v2_generator(
        FLAGS.resnet_size, _NUM_CLASSES)
    inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
    logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
    values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
def model_fn(features, labels, mode):
  """Model function for CIFAR-10.

  For more information:
  https://www.tensorflow.org/guide/custom_estimators#write_a_model_function
  """
  inputs = features[INPUT_TENSOR_NAME]
  tf.summary.image('images', inputs, max_outputs=6)
  network = resnet_model.cifar10_resnet_v2_generator(RESNET_SIZE, NUM_CLASSES)
  inputs = tf.reshape(inputs, [-1, HEIGHT, WIDTH, DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    export_outputs = {
        SIGNATURE_NAME: tf.estimator.export.PredictOutput(predictions)
    }
    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, export_outputs=export_outputs)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=tf.one_hot(labels, 10))

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
    values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  return tf.estimator.EstimatorSpec(
      mode=mode, predictions=predictions, loss=loss, train_op=train_op)
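# Since this model_fn returns export_outputs under SIGNATURE_NAME, the
# trained estimator can be exported for serving. A hedged sketch: the input
# key INPUT_TENSOR_NAME and the image shape follow the snippet above, while
# the estimator variable and export directory are placeholders.
import tensorflow as tf

def serving_input_fn():
  inputs = {
      INPUT_TENSOR_NAME: tf.placeholder(tf.float32, [None, HEIGHT, WIDTH, DEPTH])
  }
  return tf.estimator.export.ServingInputReceiver(inputs, inputs)

# estimator.export_savedmodel('/tmp/cifar10_export', serving_input_fn)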
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _NUM_CLASSES, params['data_format'])
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  clabels = labels[:, :_NUM_CLASSES]
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN, name="main")
  probs = tf.nn.softmax(logits, axis=1)

  # Calculate loss, which includes softmax cross entropy.
  base_loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          logits=logits, labels=clabels),
      name="base_loss")
  loss = base_loss
  loss = tf.identity(loss, name="loss_vec")
  loss_sum = tf.summary.scalar("loss", loss)

  rate = tf.reduce_max(probs, axis=1)

  # Print extra stuff here.
  classes = tf.argmax(logits, axis=1)
  accuracy_m = tf.metrics.accuracy(
      tf.argmax(clabels, axis=1), classes, name="accuracy_metric")
  accuracy = tf.identity(accuracy_m[1], name="accuracy_vec")
  accuracy_sum = tf.summary.scalar("accuracy", accuracy)

  if mode == tf.estimator.ModeKeys.EVAL or params["predict"]:
    # Note this is labels, not clabels.
    print_labels = tf.argmax(labels, axis=1)
    print_rate = rate
    print_probs = probs
    print_logits = logits
    hooks = []
    eval_metric_ops = {"accuracy": accuracy_m}

    # Printing stuff if predict.
    if params["predict"]:
      loss = tf.Print(loss, [print_labels], summarize=1000000,
                      message='Targets')
      loss = tf.Print(loss, [print_rate], summarize=1000000, message='Rate')
      loss = tf.Print(loss, [print_probs], summarize=1000000, message='Probs')
      loss = tf.Print(loss, [print_logits], summarize=1000000,
                      message='Logits')
      hooks = []
      eval_metric_ops = {}

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        eval_metric_ops=eval_metric_ops,
        # evaluation_hooks=hooks,
    )

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values,
        name="learning_rate_vec")
    learning_rate_sum = tf.summary.scalar("learning_rate", learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    hook = tf.train.SummarySaverHook(
        summary_op=tf.summary.merge([accuracy_sum, learning_rate_sum]),
        save_steps=1,
    )
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        training_hooks=[hook],
    )
def cifar10_logit_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _NUM_CLASSES, params['data_format'])
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  # Adding logit 0 for NOTA.
  if params["variant"] != "none":
    logits = tf.pad(logits, [[0, 0], [0, 1]], "CONSTANT")

  classes = tf.argmax(logits, axis=1)

  # if mode == tf.estimator.ModeKeys.PREDICT:
  #   predictions = {
  #       'classes': classes,
  #       'probabilities': tf.nn.softmax(logits, name='softmax_tensor'),
  #   }
  #   return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), classes, name="accuracy_metric")
  accuracy = tf.identity(accuracy[1], name="accuracy_vec")
  accuracy_sum = tf.summary.scalar("accuracy", accuracy)

  # Calculate loss, which includes softmax cross entropy.
  base_loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          logits=logits, labels=labels),
      name="base_loss")

  lnfactor = 0
  pvals = 1 - labels
  pvals = pvals / tf.reduce_sum(pvals, axis=-1, keepdims=True)
  distr = tfp.distributions.Categorical(probs=pvals)
  neg_samples = tf.transpose(distr.sample([_NUM_PEN_CLASSES]))
  mask = tf.one_hot(neg_samples, depth=_NUM_CLASSES + 1, axis=1)
  mask = tf.reduce_sum(mask, axis=(-1))

  if params["variant"] == "den":
    lnfactor = tf.reduce_mean(
        tf.log(1 + tf.reduce_sum(tf.exp(logits) * mask, axis=1)))
  elif params["variant"] == "num":
    # TODO: update this according to the new custom_softmax_cross_entropy.
    lnfactor = -custom_softmax_cross_entropy(
        logits=logits, labels=mask) / _NUM_PEN_CLASSES
  elif params["variant"] == "pen":
    neg_logits_mean = tf.reduce_mean(mask * logits, axis=1)
    lnfactor = tf.reduce_mean(
        tf.square(tf.nn.softplus(-neg_logits_mean))) / (_NUM_PEN_CLASSES * 20)
  elif params["variant"] == "cen":
    std = tf.get_variable(
        name="std_logits", shape=(1),
        initializer=tf.ones_initializer(), trainable=True)
    distr = tfp.distributions.MultivariateNormalDiag(
        loc=tf.zeros(_NUM_CLASSES + 1),
        scale_identity_multiplier=std,
    )
    probs = distr.prob(logits)
    lnfactor = -tf.reduce_mean(tf.log(probs + EPSILON))

  loss = base_loss + params["lamb"] * lnfactor
  loss = tf.identity(loss, name="loss_vec")
  loss_sum = tf.summary.scalar("loss", loss)

  if mode == tf.estimator.ModeKeys.EVAL or mode == tf.estimator.ModeKeys.PREDICT:
    # Printing stuff if predict.
    if mode == tf.estimator.ModeKeys.PREDICT or params["predict"]:
      loss = tf.Print(loss, [tf.argmax(labels, 1)], summarize=1000000,
                      message='Targets')
      loss = tf.Print(loss, [tf.argmax(logits, 1)], summarize=1000000,
                      message='Predictions')
      loss = tf.Print(loss, [tf.nn.softmax(logits)], summarize=1000000,
                      message='Probs')
      loss = tf.Print(loss, [logits], summarize=1000000, message='Logits')

    hook = tf.train.SummarySaverHook(
        summary_op=tf.summary.merge([accuracy_sum]),
        output_dir=os.path.join(params["model_dir"], "eval_core"),
        save_steps=1,
    )
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        evaluation_hooks=[hook],
    )

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values,
        name="learning_rate_vec")
    learning_rate_sum = tf.summary.scalar("learning_rate", learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    hook = tf.train.SummarySaverHook(
        summary_op=tf.summary.merge([accuracy_sum, learning_rate_sum]),
        save_steps=1,
    )
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        training_hooks=[hook],
    )
def train():
  global updated_batch_size_num
  global passed_info
  global shall_update

  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  print('PS hosts are: %s' % ps_hosts)
  print('Worker hosts are: %s' % worker_hosts)
  issync = FLAGS.sync

  cluster = tf.train.ClusterSpec({'ps': ps_hosts, 'worker': worker_hosts})
  server = tf.train.Server(
      cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

  if FLAGS.job_name == 'ps':
    server.join()
  elif FLAGS.job_name == "worker":
    time.sleep(10)
    is_chief = (FLAGS.task_index == 0)
    if is_chief:
      if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
      tf.gfile.MakeDirs(FLAGS.train_dir)

    # modified by faye
    with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % FLAGS.task_index,
        cluster=cluster)):
      global_step = tf.get_variable(
          'global_step', [],
          initializer=tf.constant_initializer(0),
          trainable=False)
      decay_steps = 50000 * 350.0 / FLAGS.batch_size
      batch_size = tf.placeholder(dtype=tf.int32, shape=(), name='batch_size')
      images, labels = cifar10.distorted_inputs(batch_size)
      print('zx0')
      print(images.get_shape().as_list())
      # print(str(tf.shape(images)) + str(tf.shape(labels)))
      re = tf.shape(images)[0]

      network = resnet_model.cifar10_resnet_v2_generator(
          FLAGS.resnet_size, _NUM_CLASSES)
      inputs = tf.reshape(images, [-1, _HEIGHT, _WIDTH, _DEPTH])
      # labels = tf.reshape(labels, [-1, _NUM_CLASSES])
      labels = tf.one_hot(labels, 10, 1, 0)
      logits = network(inputs, True)
      print(logits.get_shape())

      cross_entropy = tf.losses.softmax_cross_entropy(
          logits=logits, onehot_labels=labels)
      loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
          [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

      # Decay the learning rate exponentially based on the number of steps.
      lr = tf.train.exponential_decay(
          INITIAL_LEARNING_RATE,
          global_step,
          decay_steps,
          LEARNING_RATE_DECAY_FACTOR,
          staircase=True)
      opt = tf.train.GradientDescentOptimizer(lr)

      # Track the moving averages of all trainable variables.
      exp_moving_averager = tf.train.ExponentialMovingAverage(
          MOVING_AVERAGE_DECAY, global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      variables_averages_op = exp_moving_averager.apply(
          tf.trainable_variables())

      # added by faye
      # grads = opt.compute_gradients(loss)
      grads0 = opt.compute_gradients(loss)
      grads = [(tf.scalar_mul(
          tf.cast(batch_size / FLAGS.batch_size, tf.float32), grad), var)
               for grad, var in grads0]

      if issync == 1:
        opt = tf.train.SyncReplicasOptimizer(
            opt,
            replicas_to_aggregate=len(worker_hosts),
            # replica_id=FLAGS.task_id,
            total_num_replicas=len(worker_hosts),
            variable_averages=exp_moving_averager,
            variables_to_average=variables_to_average)
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
        chief_queue_runners = opt.get_chief_queue_runner()
        init_tokens_op = opt.get_init_tokens_op()
      else:
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

      train_op = tf.group(apply_gradient_op, variables_averages_op)

      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=FLAGS.train_dir,
          init_op=tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer()),
          summary_op=None,
          global_step=global_step,
          # saver=saver,
          saver=None,
          recovery_wait_secs=1,
          save_model_secs=60)
      tf.logging.info('%s Supervisor' % datetime.now())

      sess_config = tf.ConfigProto(
          allow_soft_placement=True,
          log_device_placement=FLAGS.log_device_placement)

      # Get a session.
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
      # sess.run(tf.global_variables_initializer())

      # Start the queue runners.
      if is_chief and issync == 1:
        sess.run(init_tokens_op)
        sv.start_queue_runners(sess, [chief_queue_runners])
      else:
        sv.start_queue_runners(sess=sess)
        # sess.run(init_tokens_op)

      # """Train CIFAR-10 for a number of steps."""
      step = 0
      g_step = 0
      batch_size_num = FLAGS.batch_size
      while g_step <= FLAGS.max_steps:
        start_time = time.time()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        if step <= 5:
          batch_size_num = FLAGS.batch_size
        num_batches_per_epoch = (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
                                 / batch_size_num)
        decay_steps_num = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        _, loss_value, g_step = sess.run(
            [train_op, loss, global_step],
            feed_dict={batch_size: batch_size_num},
            options=run_options,
            run_metadata=run_metadata)
        tl = timeline.Timeline(run_metadata.step_stats)
        ctf = tl.generate_chrome_trace_format()

        if step % 1 == 0:
          duration = time.time() - start_time
          num_examples_per_step = batch_size_num
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)
          format_str = ('%s: step %d (global_step %d), loss = %.2f '
                        '(%.1f examples/sec; %.3f sec/batch)')
          tf.logging.info(format_str % (datetime.now(), step, g_step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))
        step += 1
      # end of while
      sv.stop()
def cifar10_model_fn(config, features, labels, mode):
  """Model function for CIFAR-10."""
  tf.summary.image('images', features, max_outputs=6)
  network = resnet_model.cifar10_resnet_v2_generator(
      RESNET_SIZE, config.data_cfg.class_number)
  inputs = tf.reshape(features, [
      -1, config.data_cfg.image_height, config.data_cfg.image_width,
      config.data_cfg.image_channel
  ])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + config.train_cfg.weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    global_step = tf.train.get_or_create_global_step()

    # Decay the learning rate by lr_step.alpha every lr_step.epoch epochs.
    _INITIAL_LEARNING_RATE = (
        config.train_cfg.init_lr * config.train_cfg.batch_size / 128)
    _BATCHES_PER_EPOCH = (
        config.data_cfg.train_number / config.train_cfg.batch_size)
    if config.train_cfg.lr_policy == "lr_step":
      bound_epochs = range(0, config.train_cfg.train_epochs,
                           config.train_cfg.lr_step.epoch)
      boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in bound_epochs]
      values = [
          _INITIAL_LEARNING_RATE * (config.train_cfg.lr_step.alpha ** time)
          for time in range(0, len(bound_epochs) + 1)
      ]
      learning_rate = tf.train.piecewise_constant(
          tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=config.train_cfg.momentum)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  tf.summary.image('images', features, max_outputs=6)

  if FLAGS.network == 'imagenet':
    network = resnet_model.imagenet_resnet_v2(
        params['resnet_size'], _NUM_CLASSES, params['data_format'])
  elif FLAGS.network == 'resnet_dropout':
    network = resnet_model.cifar10_resnet_dropout_generator(
        params['resnet_size'], _NUM_CLASSES, params['data_format'],
        keep_prob=FLAGS.keep_prob)
  elif FLAGS.network == 'resnet_bottleneck':
    network = resnet_model.cifar10_resnet_bottleneck_generator(
        params['resnet_size'], _NUM_CLASSES, params['data_format'])
  else:
    network = resnet_model.cifar10_resnet_v2_generator(
        params['resnet_size'], _NUM_CLASSES, params['data_format'])

  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + FLAGS.weight_decay * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 60, 100, 150, and 170 epochs.
    boundaries = [
        int(batches_per_epoch * epoch) for epoch in [60, 100, 150, 170]
    ]
    values = [
        initial_learning_rate * decay
        for decay in [1, 0.1, 0.01, 0.001, 0.0001]
    ]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
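# A hedged sketch of the flag definitions this variant-selecting model_fn
# assumes; the names mirror the FLAGS references above, but the defaults and
# help strings are guesses, not taken from the original code.
import tensorflow as tf

tf.app.flags.DEFINE_string(
    'network', 'resnet',
    'One of: imagenet, resnet_dropout, resnet_bottleneck, or the default '
    'CIFAR-10 ResNet v2.')
tf.app.flags.DEFINE_float('keep_prob', 0.8,
                          'Keep probability for resnet_dropout.')
tf.app.flags.DEFINE_float('weight_decay', 2e-4, 'L2 regularization weight.')
FLAGS = tf.app.flags.FLAGS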
def train():
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  print('PS hosts are: %s' % ps_hosts)
  print('Worker hosts are: %s' % worker_hosts)

  configP = tf.ConfigProto()
  server = tf.train.Server(
      {'ps': ps_hosts, 'worker': worker_hosts},
      job_name=FLAGS.job_name,
      task_index=FLAGS.task_id,
      config=configP)

  batchSizeManager = BatchSizeManager(FLAGS.batch_size, len(worker_hosts))

  if FLAGS.job_name == 'ps':
    rpcServer = batchSizeManager.create_rpc_server(ps_hosts[0].split(':')[0])
    rpcServer.serve()
    server.join()

  rpcClient = batchSizeManager.create_rpc_client(ps_hosts[0].split(':')[0])
  is_chief = (FLAGS.task_id == 0)
  if is_chief:
    if tf.gfile.Exists(FLAGS.train_dir):
      tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)

  device_setter = tf.train.replica_device_setter(ps_tasks=len(ps_hosts))
  with tf.device('/job:worker/task:%d' % FLAGS.task_id):
    with tf.device(device_setter):
      global_step = tf.Variable(0, trainable=False)
      decay_steps = 50000 * 350.0 / FLAGS.batch_size
      batch_size = tf.placeholder(dtype=tf.int32, shape=(), name='batch_size')
      images, labels = cifar10.distorted_inputs(batch_size)
      re = tf.shape(images)[0]

      with tf.variable_scope(
          'root',
          partitioner=tf.fixed_size_partitioner(len(ps_hosts), axis=0)):
        network = resnet_model.cifar10_resnet_v2_generator(
            FLAGS.resnet_size, _NUM_CLASSES)
        inputs = tf.reshape(images, [-1, _HEIGHT, _WIDTH, _DEPTH])
        # labels = tf.reshape(labels, [-1, _NUM_CLASSES])
        print(labels.get_shape())
        labels = tf.one_hot(labels, 10, 1, 0)
        print(labels.get_shape())
        logits = network(inputs, True)
        print(logits.get_shape())
        cross_entropy = tf.losses.softmax_cross_entropy(
            logits=logits, onehot_labels=labels)
        # logits = cifar10.inference(images, batch_size)
        # loss = cifar10.loss(logits, labels, batch_size)
        loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
            [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

      # Decay the learning rate exponentially based on the number of steps.
      lr = tf.train.exponential_decay(
          INITIAL_LEARNING_RATE,
          global_step,
          decay_steps,
          LEARNING_RATE_DECAY_FACTOR,
          staircase=True)
      opt = tf.train.GradientDescentOptimizer(lr)

      # Track the moving averages of all trainable variables.
      exp_moving_averager = tf.train.ExponentialMovingAverage(
          MOVING_AVERAGE_DECAY, global_step)
      variables_to_average = (
          tf.trainable_variables() + tf.moving_average_variables())
      opt = tf.train.SyncReplicasOptimizer(
          opt,
          replicas_to_aggregate=len(worker_hosts),
          # replica_id=FLAGS.task_id,
          total_num_replicas=len(worker_hosts),
          variable_averages=exp_moving_averager,
          variables_to_average=variables_to_average)

      grads0 = opt.compute_gradients(loss)
      grads = [(tf.scalar_mul(
          tf.cast(batch_size / FLAGS.batch_size, tf.float32), grad), var)
               for grad, var in grads0]
      apply_gradients_op = opt.apply_gradients(grads, global_step=global_step)
      with tf.control_dependencies([apply_gradients_op]):
        train_op = tf.identity(loss, name='train_op')

      chief_queue_runners = [opt.get_chief_queue_runner()]
      init_tokens_op = opt.get_init_tokens_op()

      # saver = tf.train.Saver()
      sv = tf.train.Supervisor(
          is_chief=is_chief,
          logdir=FLAGS.train_dir,
          init_op=tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer()),
          summary_op=None,
          global_step=global_step,
          # saver=saver,
          saver=None,
          recovery_wait_secs=1,
          save_model_secs=60)
      tf.logging.info('%s Supervisor' % datetime.now())

      sess_config = tf.ConfigProto(
          allow_soft_placement=True,
          intra_op_parallelism_threads=1,
          inter_op_parallelism_threads=1,
          log_device_placement=FLAGS.log_device_placement)
      sess_config.gpu_options.allow_growth = True

      # Get a session.
      sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
      # sess.run(tf.global_variables_initializer())

      # Start the queue runners.
      queue_runners = tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
      sv.start_queue_runners(sess, queue_runners)
      sv.start_queue_runners(sess, chief_queue_runners)
      sess.run(init_tokens_op)

      """Train CIFAR-10 for a number of steps."""
      time0 = time.time()
      batch_size_num = FLAGS.batch_size
      for step in range(FLAGS.max_steps):
        start_time = time.time()
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

        # batch_size_num = updated_batch_size_num
        if step <= 5:
          batch_size_num = FLAGS.batch_size
        if step >= 0:
          batch_size_num = int(step / 5) % 512 + 1
          batch_size_num = 128

        num_batches_per_epoch = (NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
                                 / batch_size_num)
        decay_steps_num = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

        # mgrads, images_, train_val, real, loss_value, gs = sess.run(
        #     [grads, images, train_op, re, loss, global_step],
        #     feed_dict={batch_size: batch_size_num},
        #     options=run_options, run_metadata=run_metadata)
        _, loss_value, gs = sess.run(
            [train_op, loss, global_step],
            feed_dict={batch_size: batch_size_num},
            options=run_options,
            run_metadata=run_metadata)
        # _, loss_value, gs = sess.run([train_op, loss, global_step],
        #                              feed_dict={batch_size: batch_size_num})
        b = time.time()
        # tl = timeline.Timeline(run_metadata.step_stats)
        # last_batch_time = tl.get_local_step_duration('sync_token_q_Dequeue')
        # thread = threading2.Thread(target=get_computation_time,
        #                            name="get_computation_time",
        #                            args=(run_metadata.step_stats, step,))
        # thread.start()
        c0 = time.time()
        # batch_size_num = rpcClient.update_batch_size(
        #     FLAGS.task_id, last_batch_time, available_cpu, available_memory,
        #     step, batch_size_num)
        batch_size_num = rpcClient.update_batch_size(
            FLAGS.task_id, 0, 0, 0, step, batch_size_num)

        if step % 1 == 0:
          duration = time.time() - start_time
          num_examples_per_step = batch_size_num
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)
          c = time.time()
          # tf.logging.info("time statistics - batch_process_time: " +
          #                 str(last_batch_time) +
          #                 " - train_time: " + str(b - start_time) +
          #                 " - get_batch_time: " + str(c0 - b) +
          #                 " - get_bs_time: " + str(c - c0) +
          #                 " - accum_time: " + str(c - time0))
          format_str = (
              "time: " + str(time.time()) +
              '; %s: step %d (global_step %d), loss = %.2f '
              '(%.1f examples/sec; %.3f sec/batch) - batch_size: ' +
              str(batch_size_num))
          tf.logging.info(format_str % (datetime.now(), step, gs, loss_value,
                                        examples_per_sec, sec_per_batch))
def cifar10_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10."""
  tf.summary.image('images', features, max_outputs=6)
  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _NUM_CLASSES, params['data_format'])
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, mode == tf.estimator.ModeKeys.TRAIN)

  predictions = {
      'classes': tf.argmax(logits, axis=1),
      'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
  }

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes.
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)
  else:
    train_op = None

  accuracy = tf.metrics.accuracy(
      tf.argmax(labels, axis=1), predictions['classes'])
  metrics = {'accuracy': accuracy}

  # Create a tensor named train_accuracy for logging purposes.
  tf.identity(accuracy[1], name='train_accuracy')
  tf.summary.scalar('train_accuracy', accuracy[1])

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=metrics)
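# Worked numbers for the schedule above, assuming the usual CIFAR-10 training
# set size of 50,000 images and a batch size of 128: the learning-rate drops
# land at roughly 39k, 58k, and 78k global steps.
batches_per_epoch = 50000 / 128        # ~390.6 batches per epoch
boundaries = [int(batches_per_epoch * e) for e in [100, 150, 200]]
print(boundaries)                      # [39062, 58593, 78125]
values = [0.1 * d for d in [1, 0.1, 0.01, 0.001]]
print(values)                          # [0.1, 0.01, 0.001, 0.0001]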
def cifar10_vibo_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10 with a variational information bottleneck."""
  _DIM_Z = params["dim_z"]

  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _DIM_Z * 2, params['data_format'])
  logits_from_z = tf.layers.Dense(_NUM_CLASSES)

  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  clabels = labels[:, :_NUM_CLASSES]

  # The network outputs the parameters of a diagonal Gaussian over z:
  # the first _DIM_Z units are the mean, the rest parameterize the scale.
  params_z = network(inputs, mode == tf.estimator.ModeKeys.TRAIN, name="main")
  mean_z = params_z[:, :_DIM_Z]
  std_z = tf.nn.softplus(params_z[:, _DIM_Z:])
  distr_z = tfp.distributions.MultivariateNormalDiag(
      loc=mean_z, scale_diag=std_z)

  # Monte Carlo estimate of the expected cross entropy: draw _NUM_SAMPLES_Z
  # samples of z, classify each, and average over samples and batch.
  z_samples = distr_z.sample(_NUM_SAMPLES_Z)
  logits_samples = logits_from_z(z_samples)
  br_clabels = tf.expand_dims(clabels, axis=0)
  br_clabels = tf.broadcast_to(br_clabels, shape=tf.shape(logits_samples))
  base_loss = tf.reduce_mean(
      tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits_v2(
              logits=logits_samples, labels=br_clabels),
          axis=0),
      axis=0,
      name="base_loss")

  # Learned Gaussian prior over z; the KL term penalizes the encoder for
  # deviating from it.
  mean_prior = tf.get_variable("prior_mean", [_DIM_Z])
  std_prior = tf.nn.softplus(tf.get_variable("prior_std", [_DIM_Z]))
  distr_prior_z = tfp.distributions.MultivariateNormalDiag(
      loc=tf.expand_dims(mean_prior, axis=0),
      scale_diag=tf.expand_dims(std_prior, axis=0))
  kldivs = tfp.distributions.kl_divergence(distr_z, distr_prior_z)
  kldiv_term = tf.reduce_mean(kldivs, axis=0)

  loss = base_loss + params["lamb"] * kldiv_term
  loss = tf.identity(loss, name="loss_vec")
  loss_sum = tf.summary.scalar("loss", loss)

  # Single-sample prediction path.
  sample_z = distr_z.sample()
  logits = logits_from_z(sample_z)
  probs = tf.nn.softmax(logits, axis=1)

  # kl = 0 implies in-distribution and kl -> inf implies out-of-distribution,
  # so rate = exp(-kl) is near 1 for in-distribution inputs and decays to 0.
  rate = tf.exp(-kldivs)

  classes = tf.argmax(logits, axis=1)
  accuracy_m = tf.metrics.accuracy(
      tf.argmax(clabels, axis=1), classes, name="accuracy_metric")
  accuracy = tf.identity(accuracy_m[1], name="accuracy_vec")
  accuracy_sum = tf.summary.scalar("accuracy", accuracy)

  if mode == tf.estimator.ModeKeys.EVAL or params["predict"]:
    # Note: this uses labels, not clabels.
    print_labels = tf.argmax(labels, axis=1)
    print_rate = rate
    print_probs = probs
    print_logits = logits
    eval_metric_ops = {"accuracy": accuracy_m}

    # Print the per-example tensors when predicting.
    if params["predict"]:
      loss = tf.Print(loss, [print_labels], summarize=1000000,
                      message='Targets')
      loss = tf.Print(loss, [print_rate], summarize=1000000, message='Rate')
      loss = tf.Print(loss, [print_probs], summarize=1000000, message='Probs')
      loss = tf.Print(loss, [print_logits], summarize=1000000,
                      message='Logits')
      eval_metric_ops = {}

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values,
        name="learning_rate_vec")
    learning_rate_sum = tf.summary.scalar("learning_rate", learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    hook = tf.train.SummarySaverHook(
        summary_op=tf.summary.merge([accuracy_sum, learning_rate_sum]),
        save_steps=1)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, training_hooks=[hook])
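# For reference (added here, not in the original): `kldivs` above is the
# closed-form KL divergence between the encoder's diagonal Gaussian q(z|x)
# and the learned diagonal Gaussian prior p(z). A minimal NumPy sketch of the
# quantity tfp.distributions.kl_divergence returns, useful as a sanity check:
import numpy as np

def kl_diag_gaussians(mean_q, std_q, mean_p, std_p):
  """Closed-form KL(N(mean_q, diag(std_q^2)) || N(mean_p, diag(std_p^2)))."""
  var_q, var_p = std_q ** 2, std_p ** 2
  return 0.5 * np.sum(
      var_q / var_p
      + (mean_q - mean_p) ** 2 / var_p
      - 1.0
      + np.log(var_p / var_q),
      axis=-1)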
def main(unused_argv):
  # Using the Winograd non-fused algorithms provides a small performance boost.
  # os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Build the training graph directly, without an Estimator: network, loss,
  # and a momentum optimizer with a piecewise-constant learning rate.
  network = resnet_model.cifar10_resnet_v2_generator(
      FLAGS.resnet_size, _NUM_CLASSES)

  features, labels = input_fn(is_training=True)
  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  logits = network(inputs, True)

  # Calculate loss, which includes softmax cross entropy and L2 regularization.
  cross_entropy = tf.losses.softmax_cross_entropy(
      logits=logits, onehot_labels=labels)

  # Create a tensor named cross_entropy for logging purposes.
  tf.identity(cross_entropy, name='cross_entropy')
  tf.summary.scalar('cross_entropy', cross_entropy)

  # Add weight decay to the loss.
  loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
      [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

  global_step = tf.train.get_or_create_global_step()

  # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
  boundaries = [int(_BATCHES_PER_EPOCH * epoch) for epoch in [100, 150, 200]]
  values = [_INITIAL_LEARNING_RATE * decay for decay in [1, 0.1, 0.01, 0.001]]
  learning_rate = tf.train.piecewise_constant(
      tf.cast(global_step, tf.int32), boundaries, values)

  # Create a tensor named learning_rate for logging purposes.
  tf.identity(learning_rate, name='learning_rate')
  tf.summary.scalar('learning_rate', learning_rate)

  optimizer = tf.train.MomentumOptimizer(
      learning_rate=learning_rate, momentum=_MOMENTUM)

  # Batch norm requires update ops to be added as a dependency to train_op.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss, global_step)

  # Plain session loop: initialize variables and run 1000 training steps.
  init = tf.global_variables_initializer()
  sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
  sess.run(init)
  for step in range(1000):
    _, loss_value, gs = sess.run([train_op, loss, global_step])
    print('step %d; loss: %s' % (step, loss_value))
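# Entry-point sketch (an assumption; this follows the standard TF 1.x
# convention rather than anything shown in the original): tf.app.run parses
# flags and forwards the remaining argv to main.
if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main=main)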
def cifar10_binc_model_fn(features, labels, mode, params):
  """Model function for CIFAR-10 with per-class binary confidences."""
  _DIM_Z = params["dim_z"]

  network = resnet_model.cifar10_resnet_v2_generator(
      params['resnet_size'], _DIM_Z, params['data_format'])
  logits_from_z = tf.layers.Dense(_NUM_CLASSES, name="logits_z")
  confs_from_z = tf.layers.Dense(_NUM_CLASSES, use_bias=False, name="confs_z")

  inputs = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _DEPTH])
  clabels = labels[:, :_NUM_CLASSES]

  z_space = network(inputs, mode == tf.estimator.ModeKeys.TRAIN, name="main")
  z_space = tf.nn.relu(z_space)

  logits = logits_from_z(z_space)
  probs = tf.nn.softmax(logits, axis=1)
  base_loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits_v2(
          logits=logits, labels=clabels),
      name="base_loss")

  # Per-class confidence head, trained with a binary loss plus an MMCE-style
  # calibration penalty.
  confs = confs_from_z(z_space)
  confs = tf.sigmoid(confs)
  ct_loss = tf.reduce_mean(
      per_class_bin_loss(confs, clabels, params["milden"]), axis=[0, 1])
  cc_term = per_class_mmce_loss(confs, clabels)

  loss = base_loss + params["lamb"] * (ct_loss + params["mmcec"] * cc_term)
  loss = tf.identity(loss, name="loss_vec")
  loss_sum = tf.summary.scalar("loss", loss)

  rate = tf.reduce_max(confs, axis=1)

  # Summaries.
  conf_hist_sum = tf.summary.histogram("confidence", confs)

  # Construct binary labels: True where the argmax label is not the extra
  # class index _NUM_CLASSES (presumably marking out-of-distribution inputs).
  blabels = tf.logical_not(
      tf.equal(tf.argmax(labels, axis=1), _NUM_CLASSES))

  # Calibration graph: the confidence of the binary decision, kept only where
  # the decision agrees with the binary label.
  cpreds = tf.greater(rate, 0.5)
  crates = tf.where(cpreds, rate, 1 - rate)
  cps = tf.to_float(tf.logical_and(cpreds, blabels))
  rate_cal = cps * crates
  rate_cal_sum = tf.summary.histogram("accur_confidence", rate_cal)

  classes = tf.argmax(logits, axis=1)
  accuracy_m = tf.metrics.accuracy(
      tf.argmax(clabels, axis=1), classes, name="accuracy_metric")
  accuracy = tf.identity(accuracy_m[1], name="accuracy_vec")
  accuracy_sum = tf.summary.scalar("accuracy", accuracy)

  if mode == tf.estimator.ModeKeys.EVAL or params["predict"]:
    # Note: this uses labels, not clabels.
    print_labels = tf.argmax(labels, axis=1)
    print_rate = rate
    print_confs = confs
    print_probs = probs
    print_logits = logits
    eval_metric_ops = {"accuracy": accuracy_m}

    # Print the per-example tensors when predicting.
    if params["predict"]:
      loss = tf.Print(loss, [print_labels], summarize=1000000,
                      message='Targets')
      loss = tf.Print(loss, [print_rate], summarize=1000000, message='Rate')
      loss = tf.Print(loss, [print_confs], summarize=1000000, message='Confs')
      loss = tf.Print(loss, [print_probs], summarize=1000000, message='Probs')
      loss = tf.Print(loss, [print_logits], summarize=1000000,
                      message='Logits')
      eval_metric_ops = {}

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

  if mode == tf.estimator.ModeKeys.TRAIN:
    # Scale the learning rate linearly with the batch size. When the batch
    # size is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = _NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values,
        name="learning_rate_vec")
    learning_rate_sum = tf.summary.scalar("learning_rate", learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=_MOMENTUM)

    # Batch norm requires update ops to be added as a dependency to train_op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      train_op = optimizer.minimize(loss, global_step)

    hook = tf.train.SummarySaverHook(
        summary_op=tf.summary.merge(
            [accuracy_sum, learning_rate_sum, conf_hist_sum, rate_cal_sum]),
        save_steps=1)

    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss, train_op=train_op, training_hooks=[hook])
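# per_class_bin_loss and per_class_mmce_loss are defined elsewhere in the
# project and are not shown here. Purely as an illustration of the call
# signature used above (confidences in [0, 1], one-hot labels, a "milden"
# weight on the negative class), a hypothetical per-class binary
# cross-entropy could look like this; it is a sketch, not the author's
# actual implementation:
def per_class_bin_loss_sketch(confs, clabels, milden, eps=1e-7):
  """Hypothetical: elementwise BCE between per-class confidences and one-hot
  labels, with the negative-class term scaled by `milden`."""
  confs = tf.clip_by_value(confs, eps, 1.0 - eps)
  pos = -clabels * tf.log(confs)
  neg = -(1.0 - clabels) * tf.log(1.0 - confs)
  return pos + milden * neg  # shape [batch, num_classes]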