def test_linear_decay(self):
  global_step = tf.placeholder(dtype=tf.int32, shape=())
  decay_value = custom_layers.linear_decay(global_step, 10)
  with self.session() as sess:
    self.assertAllClose(sess.run(decay_value, {global_step: -500}), 1.0)
    self.assertAllClose(sess.run(decay_value, {global_step: 0}), 1.0)
    self.assertAllClose(sess.run(decay_value, {global_step: 5}), 0.5)
    self.assertAllClose(sess.run(decay_value, {global_step: 9}), 0.1)
    self.assertAllClose(sess.run(decay_value, {global_step: 10}), 0)
    self.assertAllClose(sess.run(decay_value, {global_step: 50}), 0)
    self.assertAllClose(sess.run(decay_value, {global_step: 5000000}), 0)
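
# A minimal sketch of what custom_layers.linear_decay() might look like,
# inferred from the test above. The name and body here are assumptions, not
# the actual implementation: the value starts at 1.0, reaches 0.0 after
# `decay_steps` steps, and is clamped to [0.0, 1.0] outside that range.
def linear_decay_sketch(global_step, decay_steps):
  """Hypothetical linear decay: 1.0 at step <= 0, 0.0 at step >= decay_steps."""
  rate = 1.0 - tf.cast(global_step, tf.float32) / float(decay_steps)
  return tf.clip_by_value(rate, 0.0, 1.0)
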
def model_fn(features, labels, mode, params):
  """Constructs a TPUEstimatorSpec for a model."""
  if mode != tf.estimator.ModeKeys.TRAIN:
    raise NotImplementedError(
        'Expected that mode == TRAIN, but got {!r}'.format(mode))

  # Data was transposed from NHWC to HWCN on the host side. Transpose it back.
  # This transposition will be optimized away by the XLA compiler. It serves
  # as a hint to the compiler that it should expect the input data to come
  # in HWCN format rather than NHWC.
  train_features = tf.transpose(features['train'], [3, 0, 1, 2])
  validation_features = tf.transpose(features['validation'], [3, 0, 1, 2])

  if params['use_bfloat16'] == 'ontpu':
    train_features = tf.cast(train_features, tf.bfloat16)
    validation_features = tf.cast(validation_features, tf.bfloat16)

  global_step = tf.train.get_global_step()

  # Randomly sample a network architecture.
  with tf.variable_scope('rl_controller') as rl_scope:
    pass

  model_spec = mobile_classifier_factory.get_model_spec(params['ssd'])

  tf.io.gfile.makedirs(params['checkpoint_dir'])
  model_spec_filename = os.path.join(
      params['checkpoint_dir'], 'model_spec.json')
  with tf.io.gfile.GFile(model_spec_filename, 'w') as handle:
    handle.write(schema_io.serialize(model_spec))

  increase_ops_prob = custom_layers.linear_decay(
      global_step, params['increase_ops_warmup_steps'])
  increase_filters_prob = custom_layers.linear_decay(
      global_step, params['increase_filters_warmup_steps'])
  model_spec, dist_info = controller.independent_sample(
      model_spec,
      increase_ops_probability=increase_ops_prob,
      increase_filters_probability=increase_filters_prob,
      name=rl_scope)

  if params['enable_cost_model']:
    cost_model_features = mobile_cost_model.coupled_tf_features(model_spec)
    estimated_cost = cost_model_lib.estimate_cost(
        cost_model_features, params['ssd'])

  # We divide the regularization strength by 2 for backwards compatibility
  # with the deprecated tf.contrib.layers.l2_regularizer() function, which
  # was used in our published experiments.
  kernel_regularizer = tf.keras.regularizers.l2(
      params['model_weight_decay'] / 2)

  # Set up the basic TensorFlow training/inference graph.
  model = mobile_classifier_factory.get_model_for_search(
      model_spec, kernel_regularizer=kernel_regularizer)
  model.build(train_features.shape)

  with tf.name_scope('training'):
    model_logits, _ = model.apply(train_features, training=True)
    # Cast back to float32 (this is a no-op unless use_bfloat16 is true).
    model_logits = tf.cast(model_logits, tf.float32)

    model_empirical_loss = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['train'],
        logits=model_logits,
        label_smoothing=0.1)
    model_regularization_loss = model.regularization_loss()
    model_loss = model_empirical_loss + model_regularization_loss

    # Set up the model weight training logic.
    model_learning_rate = custom_layers.cosine_decay_with_linear_warmup(
        peak_learning_rate=params['model_learning_rate'],
        global_step=global_step,
        max_global_step=params['max_global_step'],
        warmup_steps=params['model_warmup_steps'])

    model_optimizer = tf.tpu.CrossShardOptimizer(
        tf.train.RMSPropOptimizer(
            model_learning_rate,
            decay=0.9,
            momentum=params['model_momentum'],
            epsilon=1.0))

    model_vars = model.trainable_variables()
    model_update_ops = model.updates()
    with tf.control_dependencies(model_update_ops):
      grads_and_vars = model_optimizer.compute_gradients(
          model_loss, var_list=model_vars)
      if params['use_gradient_sync_barrier']:
        # Force all gradients to be computed before any are applied.
        grads_and_vars = _grads_and_vars_barrier(grads_and_vars)

      # NOTE: We do not pass `global_step` to apply_gradients(), so the
      # global step is not incremented by `model_optimizer`. The global_step
      # will be incremented later on, when we update the RL controller
      # weights. If we incremented it here too, we'd end up incrementing the
      # global_step twice at each training step.
      model_op = model_optimizer.apply_gradients(grads_and_vars)

    if params['use_gradient_sync_barrier']:
      # Finish computing gradients for the shared model weights before we
      # start on the RL update step.
      #
      # NOTE: The barrier above forces TensorFlow to finish computing grads
      # for all of the trainable variables before any of the grads can be
      # consumed. So while the call to with_data_dependencies() here only
      # explicitly depends on grads_and_vars[0][0], the call implicitly
      # forces TensorFlow to finish computing the gradients for *all*
      # trainable variables before computing the validation features.
      validation_features = layers.with_data_dependencies(
          [grads_and_vars[0][0]], [validation_features])[0]

  with tf.name_scope('validation'):
    # Estimate the model accuracy on a batch of examples from the validation
    # set. Force this logic to run after the model optimization step.
    with tf.control_dependencies([model_op]):
      validation_logits, _ = model.apply(validation_features, training=False)

    # NOTE(b/130311965): An earlier version of this code cast
    # validation_logits from bfloat16 to float32 before applying an argmax
    # when the --use_bfloat16 flag was true. As of cl/240923609, this caused
    # XLA to compute incorrect model accuracies. Please avoid casting from
    # bfloat16 to float32 before taking the argmax.
    is_prediction_correct = tf.equal(
        tf.argmax(validation_logits, axis=1),
        tf.argmax(labels['validation'], axis=1))
    validation_accuracy = tf.reduce_mean(
        tf.cast(is_prediction_correct, tf.float32))

  # Estimate the reward for the current network architecture and update the
  # reward to incorporate the cost of the network architecture.
  if params['enable_cost_model']:
    rl_stats = search_space_utils.reward_for_single_cost_model(
        validation_accuracy,
        rl_reward_function=params['rl_reward_function'],
        estimated_cost=estimated_cost,
        rl_cost_model_target=params['rl_cost_model_target'],
        rl_cost_model_exponent=params['rl_cost_model_exponent'])
    rl_cost_ratio = rl_stats['rl_cost_ratio']
    rl_reward = rl_stats['rl_reward']
    rl_cost_adjustment = rl_stats['rl_cost_adjustment']
  else:
    rl_reward = validation_accuracy

  # Compute a baseline. We first take a cross-replica sum of the rewards
  # for all the TPU shards, then incorporate the result into an exponential
  # moving average. Within a single batch, each TPU shard will select a
  # different set of op masks from the RL controller. Each shard will
  # basically evaluate a different candidate architecture in our search
  # space.

  # Count the number of TPU shards (cores) used for training.
  num_tpu_shards = tf.tpu.cross_replica_sum(
      tf.ones(shape=(), dtype=rl_reward.dtype))
  rl_step_baseline = tf.tpu.cross_replica_sum(rl_reward)
  rl_step_baseline = rl_step_baseline / num_tpu_shards

  rl_baseline = custom_layers.update_exponential_moving_average(
      rl_step_baseline, momentum=params['rl_baseline_momentum'])

  # Apply a REINFORCE update to the RL controller.
  log_prob = dist_info['sample_log_prob']
  rl_advantage = rl_reward - rl_baseline
  rl_empirical_loss = -tf.stop_gradient(rl_advantage) * log_prob

  # We set rl_entropy_loss proportional to (-entropy) so that minimizing the
  # loss will lead to an entropy that is as large as possible.
  rl_entropy = dist_info['entropy']
  rl_entropy_loss = -params['rl_entropy_regularization'] * rl_entropy

  # We use an RL learning rate of 0 for the first N epochs of training. See
  # Appendix A of FBNet (https://arxiv.org/pdf/1812.03443.pdf). Although
  # they don't mention it explicitly, there are some indications that
  # ProxylessNAS (https://openreview.net/forum?id=HylVB3AqYm) might also be
  # doing this.
  enable_rl_optimizer = tf.cast(
      tf.greater_equal(global_step, params['rl_delay_steps']),
      tf.float32)
  rl_learning_rate = params['rl_learning_rate'] * enable_rl_optimizer

  if params['use_exponential_rl_learning_rate_schedule']:
    # rl_learning_rate_progress will be 0 when the RL controller starts
    # learning and 1 when the search ends.
    rl_learning_rate_progress = tf.nn.relu(
        tf.div(
            tf.cast(global_step - params['rl_delay_steps'], tf.float32),
            max(1, params['max_global_step'] - params['rl_delay_steps'])))
    # Exponentially increase the RL learning rate over time.
    rl_learning_rate_multiplier = tf.pow(10.0, rl_learning_rate_progress)
    rl_learning_rate = rl_learning_rate * rl_learning_rate_multiplier

  rl_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, rl_scope.name)
  with tf.control_dependencies(rl_update_ops):
    # In order to evaluate train_op, we must first evaluate
    # validation_accuracy. And to evaluate validation_accuracy, we must
    # first evaluate model_op. So running this op will perform a step of
    # model training followed by a step of RL controller training.
    if params['use_gradient_sync_barrier']:
      transform_grads_fn = _grads_and_vars_barrier
    else:
      transform_grads_fn = None

    train_op = tpu_optimizer_ops.apply_adam(
        rl_empirical_loss,
        regularization_loss=rl_entropy_loss,
        global_step=global_step,
        var_list=tf.trainable_variables(rl_scope.name),
        learning_rate=rl_learning_rate,
        beta1=0.0,
        beta2=0.999,
        epsilon=1e-8,
        transform_grads_fn=transform_grads_fn)

  # TensorBoard logging.
  tensorboard_scalars = collections.OrderedDict([
      ('model/loss', model_loss),
      ('model/empirical_loss', model_empirical_loss),
      ('model/regularization_loss', model_regularization_loss),
      ('model/learning_rate', model_learning_rate),
      ('rlcontroller/empirical_loss', rl_empirical_loss),
      ('rlcontroller/entropy_loss', rl_entropy_loss),
      ('rlcontroller/validation_accuracy', validation_accuracy),
      ('rlcontroller/reward', rl_reward),
      ('rlcontroller/step_baseline', rl_step_baseline),
      ('rlcontroller/baseline', rl_baseline),
      ('rlcontroller/advantage', rl_advantage),
      ('rlcontroller/log_prob', log_prob),
  ])

  if params['enable_cost_model']:
    tensorboard_scalars['rlcontroller/estimated_cost'] = estimated_cost
    tensorboard_scalars['rlcontroller/cost_ratio'] = rl_cost_ratio
    tensorboard_scalars['rlcontroller/cost_adjustment'] = rl_cost_adjustment

  tensorboard_scalars['rlcontroller/learning_rate'] = rl_learning_rate
  tensorboard_scalars['rlcontroller/increase_ops_prob'] = increase_ops_prob
  tensorboard_scalars['rlcontroller/increase_filters_prob'] = (
      increase_filters_prob)

  # Log the values of all the choices made by the RL controller.
  for name_i, logits_i in dist_info['logits_by_path'].items():
    assert len(logits_i.shape) == 1, logits_i
    for j in range(int(logits_i.shape[0])):
      key = 'rlpathlogits/{:s}/{:d}'.format(name_i, j)
      tensorboard_scalars[key] = logits_i[j]

  for name_i, logits_i in dist_info['logits_by_tag'].items():
    assert len(logits_i.shape) == 1, logits_i
    for j in range(int(logits_i.shape[0])):
      key = 'rltaglogits/{:s}/{:d}'.format(name_i, j)
      tensorboard_scalars[key] = logits_i[j]

  # NOTE: host_call only works on rank-1 tensors.
  # There's also a fairly large performance penalty if we try to pass too
  # many distinct tensors from the TPU to the host at once. We avoid these
  # problems by (i) calling tf.stack to merge all of the float32 scalar
  # values into a single rank-1 tensor that can be sent to the host
  # relatively cheaply and (ii) reshaping the remaining values from scalars
  # to rank-1 tensors.
  def host_call_fn(step, scalar_values):
    values = tf.unstack(scalar_values)
    with tf2.summary.create_file_writer(
        params['checkpoint_dir']).as_default():
      with tf2.summary.record_if(
          tf.math.equal(step[0] % params['tpu_iterations_per_loop'], 0)):
        for key, value in zip(list(tensorboard_scalars.keys()), values):
          tf2.summary.scalar(key, value, step=step[0])
        return tf.summary.all_v2_summary_ops()

  host_call_values = tf.stack(list(tensorboard_scalars.values()))
  host_call = (host_call_fn,
               [tf.reshape(global_step, [1]), host_call_values])

  # Construct the estimator specification.
  return tf.estimator.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=model_loss,
      train_op=train_op,
      host_call=host_call)
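
# _grads_and_vars_barrier() is called in model_fn() but not defined in this
# excerpt. Below is a minimal sketch of what it could look like, assuming it
# reuses the layers.with_data_dependencies() helper seen above: by making
# every gradient a data dependency of every other gradient, TensorFlow must
# finish computing all of the gradients before any of them can be consumed.
def _grads_and_vars_barrier(grads_and_vars):
  """Forces all gradients to be computed before any of them are applied."""
  grads, variables = zip(*grads_and_vars)
  grads = layers.with_data_dependencies(list(grads), list(grads))
  return list(zip(grads, variables))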
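
# custom_layers.update_exponential_moving_average() is likewise used above
# without a definition. A simplified sketch under assumed semantics (the
# real implementation may differ, e.g. by adding bias correction): each
# evaluation folds the current value into a non-trainable moving-average
# variable and returns the updated average.
def update_exponential_moving_average_sketch(tensor, momentum):
  """Hypothetical EMA op that updates its state on every evaluation."""
  average = tf.get_variable(
      'moving_average',
      shape=tensor.shape,
      trainable=False,
      initializer=tf.zeros_initializer())
  update_op = tf.assign(
      average, momentum * average + (1.0 - momentum) * tensor)
  with tf.control_dependencies([update_op]):
    return average.read_value()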