Example #1
def _grads_and_vars_barrier(
    grads_and_vars
):
  """Barrier that forces all grads to be computed before any are used."""
  current_grads, current_vars = list(zip(*grads_and_vars))
  current_grads = layers.with_data_dependencies(current_grads, current_grads)
  return list(zip(current_grads, current_vars))
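
For context, a minimal, hypothetical usage sketch of the barrier above (not part of the original source): it sits between compute_gradients() and apply_gradients(), so every gradient is materialized before any variable update runs. The loss and variables below are placeholders, and the sketch assumes the project-level layers module that provides with_data_dependencies() is importable.

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical toy variables and loss, purely for illustration.
w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())
b = tf.get_variable('b', shape=[], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w)) + tf.square(b)

optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
grads_and_vars = optimizer.compute_gradients(loss, var_list=[w, b])
# All gradients are forced to be computed before any of them is applied.
grads_and_vars = _grads_and_vars_barrier(grads_and_vars)
train_op = optimizer.apply_gradients(grads_and_vars)
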
Example #2
def model_fn(features, labels, mode, params):
    """Construct a TPUEstimatorSpec for a model."""
    if mode != tf.estimator.ModeKeys.TRAIN:
        raise NotImplementedError(
            'Expected that mode == TRAIN, but got {!r}'.format(mode))

    # Data was transposed from NHWC to HWCN on the host side. Transpose it back.
    # This transposition will be optimized away by the XLA compiler. It serves
    # as a hint to the compiler that it should expect the input data to come
    # in HWCN format rather than NHWC.
    train_features = tf.transpose(features['train'], [3, 0, 1, 2])
    validation_features = tf.transpose(features['validation'], [3, 0, 1, 2])

    if params['use_bfloat16'] == 'ontpu':
        train_features = tf.cast(train_features, tf.bfloat16)
        validation_features = tf.cast(validation_features, tf.bfloat16)

    global_step = tf.train.get_global_step()

    # Create the variable scope for the RL controller that randomly samples
    # network architectures. (The actual sampling happens below, in
    # controller.independent_sample().)
    with tf.variable_scope('rl_controller') as rl_scope:
        pass

    model_spec = mobile_classifier_factory.get_model_spec(params['ssd'])

    tf.io.gfile.makedirs(params['checkpoint_dir'])
    model_spec_filename = os.path.join(params['checkpoint_dir'],
                                       'model_spec.json')
    with tf.io.gfile.GFile(model_spec_filename, 'w') as handle:
        handle.write(schema_io.serialize(model_spec))

    increase_ops_prob = custom_layers.linear_decay(
        global_step, params['increase_ops_warmup_steps'])
    increase_filters_prob = custom_layers.linear_decay(
        global_step, params['increase_filters_warmup_steps'])
    model_spec, dist_info = controller.independent_sample(
        model_spec,
        increase_ops_probability=increase_ops_prob,
        increase_filters_probability=increase_filters_prob,
        name=rl_scope)

    if params['enable_cost_model']:
        cost_model_features = mobile_cost_model.coupled_tf_features(model_spec)
        estimated_cost = cost_model_lib.estimate_cost(cost_model_features,
                                                      params['ssd'])

    # We divide the regularization strength by 2 for backwards compatibility with
    # the deprecated tf.contrib.layers.l2_regularizer() function, which was used
    # in our published experiments.
    kernel_regularizer = tf.keras.regularizers.l2(
        params['model_weight_decay'] / 2)

    # Set up the basic TensorFlow training/inference graph.
    model = mobile_classifier_factory.get_model_for_search(
        model_spec, kernel_regularizer=kernel_regularizer)
    model.build(train_features.shape)

    with tf.name_scope('training'):
        model_logits, _ = model.apply(train_features, training=True)
        # Cast back to float32 (a no-op unless use_bfloat16 is enabled).
        model_logits = tf.cast(model_logits, tf.float32)

        model_empirical_loss = tf.losses.softmax_cross_entropy(
            onehot_labels=labels['train'],
            logits=model_logits,
            label_smoothing=0.1)
        model_regularization_loss = model.regularization_loss()
        model_loss = model_empirical_loss + model_regularization_loss

        # Set up the model weight training logic.
        model_learning_rate = custom_layers.cosine_decay_with_linear_warmup(
            peak_learning_rate=params['model_learning_rate'],
            global_step=global_step,
            max_global_step=params['max_global_step'],
            warmup_steps=params['model_warmup_steps'])

        model_optimizer = tf.tpu.CrossShardOptimizer(
            tf.train.RMSPropOptimizer(model_learning_rate,
                                      decay=0.9,
                                      momentum=params['model_momentum'],
                                      epsilon=1.0))

        model_vars = model.trainable_variables()
        model_update_ops = model.updates()
        with tf.control_dependencies(model_update_ops):
            grads_and_vars = model_optimizer.compute_gradients(
                model_loss, var_list=model_vars)
            if params['use_gradient_sync_barrier']:
                # Force all gradients to be computed before any are applied.
                grads_and_vars = _grads_and_vars_barrier(grads_and_vars)

            # NOTE: We do not pass `global_step` to apply_gradients(), so the global
            # step is not incremented by `model_optimizer`. The global_step will be
            # incremented later on, when we update the RL controller weights. If we
            # incremented it here too, we'd end up incrementing the global_step twice
            # at each training step.
            model_op = model_optimizer.apply_gradients(grads_and_vars)
            if params['use_gradient_sync_barrier']:
                # Finish computing gradients for the shared model weights before we
                # start on the RL update step.
                #
                # NOTE: The barrier above forces TensorFlow to finish computing grads
                # for all of the trainable variables before any of the grads can be
                # consumed. So while the call to with_data_dependencies() here only
                # explicitly depends on grads_and_vars[0][0], the call implicitly forces
                # TensorFlow to finish computing the gradients for *all* trainable
                # variables before computing the validation features.
                validation_features = layers.with_data_dependencies(
                    [grads_and_vars[0][0]], [validation_features])[0]

    with tf.name_scope('validation'):
        # Estimate the model accuracy on a batch of examples from the validation
        # set. Force this logic to run after the model optimization step.
        with tf.control_dependencies([model_op]):
            validation_logits, _ = model.apply(validation_features,
                                               training=False)

        # NOTE(b/130311965): An earlier version of this code cast validation_logits
        # from bfloat16 to float32 before applying an argmax when the --use_bfloat16
        # flag was true. As of cl/240923609, this caused XLA to compute incorrect
        # model accuracies. Please avoid casting from bfloat16 to float32 before
        # taking the argmax.
        is_prediction_correct = tf.equal(
            tf.argmax(validation_logits, axis=1),
            tf.argmax(labels['validation'], axis=1))
        validation_accuracy = tf.reduce_mean(
            tf.cast(is_prediction_correct, tf.float32))

    # Estimate the reward for the current network architecture and update the
    # reward to incorporate the cost of the network architecture.
    if params['enable_cost_model']:
        rl_stats = search_space_utils.reward_for_single_cost_model(
            validation_accuracy,
            rl_reward_function=params['rl_reward_function'],
            estimated_cost=estimated_cost,
            rl_cost_model_target=params['rl_cost_model_target'],
            rl_cost_model_exponent=params['rl_cost_model_exponent'])
        rl_cost_ratio = rl_stats['rl_cost_ratio']
        rl_reward = rl_stats['rl_reward']
        rl_cost_adjustment = rl_stats['rl_cost_adjustment']
    else:
        rl_reward = validation_accuracy

    # Compute a baseline. We first take a cross-replica sum of the rewards
    # for all the TPU shards, then incorporate the result into an exponential
    # moving average. Within a single batch, each TPU shard will select a
    # different set of op masks from the RL controller. Each shard will basically
    # evaluate a different candidate architecture in our search space.

    # Count the number of TPU shards (cores) used for training.
    num_tpu_shards = tf.tpu.cross_replica_sum(
        tf.ones(shape=(), dtype=rl_reward.dtype))
    rl_step_baseline = tf.tpu.cross_replica_sum(rl_reward)
    rl_step_baseline = rl_step_baseline / num_tpu_shards

    rl_baseline = custom_layers.update_exponential_moving_average(
        rl_step_baseline, momentum=params['rl_baseline_momentum'])
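    # (Assuming the usual exponential-moving-average convention, this is roughly
    # rl_baseline <- momentum * rl_baseline + (1 - momentum) * rl_step_baseline.)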

    # Apply a REINFORCE update to the RL controller.
    log_prob = dist_info['sample_log_prob']
    rl_advantage = rl_reward - rl_baseline
    rl_empirical_loss = -tf.stop_gradient(rl_advantage) * log_prob
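    # Because the advantage is wrapped in stop_gradient(), the gradient of this
    # loss w.r.t. the controller parameters is -rl_advantage * d(log_prob)/d(theta),
    # i.e. the standard REINFORCE policy-gradient estimator.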

    # We set rl_entropy_loss proportional to (-entropy) so that minimizing the
    # loss will lead to an entropy that is as large as possible.
    rl_entropy = dist_info['entropy']
    rl_entropy_loss = -params['rl_entropy_regularization'] * rl_entropy

    # We use an RL learning rate of 0 for the first N epochs of training. See
    # Appendix A of FBNet (https://arxiv.org/pdf/1812.03443.pdf). Although they
    # don't mention it explicitly, there are some indications that ProxylessNAS
    # (https://openreview.net/forum?id=HylVB3AqYm) might also be doing this.
    enable_rl_optimizer = tf.cast(
        tf.greater_equal(global_step, params['rl_delay_steps']), tf.float32)
    rl_learning_rate = params['rl_learning_rate'] * enable_rl_optimizer

    if params['use_exponential_rl_learning_rate_schedule']:
        #  rl_learning_rate_progress will be 0 when the RL controller starts
        #  learning and 1 when the search ends.
        rl_learning_rate_progress = tf.nn.relu(
            tf.div(
                tf.cast(global_step - params['rl_delay_steps'], tf.float32),
                max(1, params['max_global_step'] - params['rl_delay_steps'])))
        # Exponentially increase the RL learning rate over time (the multiplier
        # ramps from 1x to 10x as the search progresses).
        rl_learning_rate_multiplier = tf.pow(10.0, rl_learning_rate_progress)
        rl_learning_rate = rl_learning_rate * rl_learning_rate_multiplier

    rl_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, rl_scope.name)
    with tf.control_dependencies(rl_update_ops):
        # In order to evaluate train_op, we must first evaluate validation_accuracy.
        # And to evaluate validation_accuracy, we must first evaluate model_op. So
        # running this op will perform a step of model training followed by
        # a step of RL controller training.
        if params['use_gradient_sync_barrier']:
            transform_grads_fn = _grads_and_vars_barrier
        else:
            transform_grads_fn = None

        train_op = tpu_optimizer_ops.apply_adam(
            rl_empirical_loss,
            regularization_loss=rl_entropy_loss,
            global_step=global_step,
            var_list=tf.trainable_variables(rl_scope.name),
            learning_rate=rl_learning_rate,
            beta1=0.0,
            beta2=0.999,
            epsilon=1e-8,
            transform_grads_fn=transform_grads_fn)

    # TensorBoard logging
    tensorboard_scalars = collections.OrderedDict([
        ('model/loss', model_loss),
        ('model/empirical_loss', model_empirical_loss),
        ('model/regularization_loss', model_regularization_loss),
        ('model/learning_rate', model_learning_rate),
        ('rlcontroller/empirical_loss', rl_empirical_loss),
        ('rlcontroller/entropy_loss', rl_entropy_loss),
        ('rlcontroller/validation_accuracy', validation_accuracy),
        ('rlcontroller/reward', rl_reward),
        ('rlcontroller/step_baseline', rl_step_baseline),
        ('rlcontroller/baseline', rl_baseline),
        ('rlcontroller/advantage', rl_advantage),
        ('rlcontroller/log_prob', log_prob),
    ])

    if params['enable_cost_model']:
        tensorboard_scalars['rlcontroller/estimated_cost'] = estimated_cost
        tensorboard_scalars['rlcontroller/cost_ratio'] = rl_cost_ratio
        tensorboard_scalars[
            'rlcontroller/cost_adjustment'] = rl_cost_adjustment
        tensorboard_scalars['rlcontroller/learning_rate'] = rl_learning_rate

    tensorboard_scalars['rlcontroller/increase_ops_prob'] = increase_ops_prob
    tensorboard_scalars['rlcontroller/increase_filters_prob'] = (
        increase_filters_prob)

    # Log the values of all the choices made by the RL controller.
    for name_i, logits_i in dist_info['logits_by_path'].items():
        assert len(logits_i.shape) == 1, logits_i
        for j in range(int(logits_i.shape[0])):
            key = 'rlpathlogits/{:s}/{:d}'.format(name_i, j)
            tensorboard_scalars[key] = logits_i[j]

    for name_i, logits_i in dist_info['logits_by_tag'].items():
        assert len(logits_i.shape) == 1, logits_i
        for j in range(int(logits_i.shape[0])):
            key = 'rltaglogits/{:s}/{:d}'.format(name_i, j)
            tensorboard_scalars[key] = logits_i[j]

    # NOTE: host_call only works on rank-1 tensors. There's also a fairly
    # large performance penalty if we try to pass too many distinct tensors
    # from the TPU to the host at once. We avoid these problems by (i) calling
    # tf.stack to merge all of the float32 scalar values into a single rank-1
    # tensor that can be sent to the host relatively cheaply and (ii) reshaping
    # the remaining values from scalars to rank-1 tensors.
    def host_call_fn(step, scalar_values):
        values = tf.unstack(scalar_values)
        with tf2.summary.create_file_writer(
                params['checkpoint_dir']).as_default():
            with tf2.summary.record_if(
                    tf.math.equal(step[0] % params['tpu_iterations_per_loop'],
                                  0)):
                for key, value in zip(list(tensorboard_scalars.keys()),
                                      values):
                    tf2.summary.scalar(key, value, step=step[0])
                return tf.summary.all_v2_summary_ops()

    host_call_values = tf.stack(list(tensorboard_scalars.values()))
    host_call = (host_call_fn,
                 [tf.reshape(global_step, [1]), host_call_values])

    # Construct the estimator specification.
    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=model_loss,
                                             train_op=train_op,
                                             host_call=host_call)
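
For completeness, a hypothetical sketch (not from the original source) of how a model_fn like this is typically handed to a TPUEstimator. The names master, checkpoint_dir, train_batch_size, input_fn, and the params dict are illustrative placeholders.

import tensorflow.compat.v1 as tf

# Hypothetical wiring; every unbound name here is a placeholder.
run_config = tf.estimator.tpu.RunConfig(
    master=master,
    model_dir=checkpoint_dir,
    tpu_config=tf.estimator.tpu.TPUConfig(
        iterations_per_loop=params['tpu_iterations_per_loop']))

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    params=params,
    train_batch_size=train_batch_size)

estimator.train(input_fn=input_fn, max_steps=params['max_global_step'])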