def create_logits_spec(self,
                         phoenix_spec,
                         pre_logits,
                         dimension,
                         is_frozen,
                         lengths=None):
    """Creates the logits for the tower.

    Args:
      phoenix_spec: The trial's `phoenix_spec_pb2.PhoenixSpec` proto.
      pre_logits: `tf.Tensor` of the layer before the logits layer.
      dimension: int - the output tensor last axis dimension.
      is_frozen: Whether the tower should be frozen.
      lengths: A tensor of shape [batch] holding the sequence length for a
        sequential problem (rnn).

    Returns:
      A LogitsSpec containing the main and auxiliary logits and their loss
      weights.
    """

    logits_weight = 1.0
    aux_logits = None
    aux_logits_weight = None
    if (phoenix_spec.problem_type ==
        phoenix_spec_pb2.PhoenixSpec.RNN_ALL_ACTIVATIONS):
      logits = tf.compat.v1.layers.conv1d(
          inputs=pre_logits, filters=dimension, kernel_size=1)
    elif (phoenix_spec.problem_type ==
          phoenix_spec_pb2.PhoenixSpec.RNN_LAST_ACTIVATIONS):
      if lengths is not None:
        logits = utils.last_activations_in_sequence(
            tf.compat.v1.layers.conv1d(
                inputs=pre_logits, filters=dimension, kernel_size=1), lengths)
      else:
        logging.warning("Length is missing for rnn_last problem type.")
        logits = tf.compat.v1.layers.conv1d(
            inputs=pre_logits, filters=dimension, kernel_size=1)
    elif phoenix_spec.problem_type in (phoenix_spec_pb2.PhoenixSpec.CNN,
                                       phoenix_spec_pb2.PhoenixSpec.DNN):
      logits = tf.keras.layers.Dense(dimension, name="dense")(pre_logits)
    else:
      raise ValueError("phoenix_spec.problem_type must be either DNN, CNN, "
                       "RNN_LAST_ACTIVATIONS, or RNN_ALL_ACTIVATIONS.")

    logits = tf.identity(logits, name="logits")
    if aux_logits is not None:
      aux_logits = tf.identity(aux_logits, name="aux_logits")

    # TODO(b/172564129): Remove from eval graph.
    if is_frozen:
      logits = tf.stop_gradient(logits)
      if aux_logits is not None:
        aux_logits = tf.stop_gradient(aux_logits)

    return architecture_utils.LogitsSpec(logits, logits_weight, aux_logits,
                                         aux_logits_weight)
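A minimal usage sketch for the snippet above, assuming `create_logits_spec` is a method on a tower-building object (suggested by the `self` parameter); `tower` and the import paths are assumptions, not part of the snippet:

# Hedged sketch: `tower` is a hypothetical instance of whatever class
# defines create_logits_spec; import paths follow the model_search layout.
import tensorflow as tf
from model_search.proto import phoenix_spec_pb2

spec = phoenix_spec_pb2.PhoenixSpec(
    problem_type=phoenix_spec_pb2.PhoenixSpec.DNN)
with tf.compat.v1.Graph().as_default():  # force graph mode, as in the test
  pre_logits = tf.compat.v1.placeholder(tf.float32, shape=[None, 128])
  logits_spec = tower.create_logits_spec(
      phoenix_spec=spec,
      pre_logits=pre_logits,
      dimension=10,       # e.g. 10 output classes
      is_frozen=False)    # gradients may flow into the tower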
Example #2
def test_last_activations_in_sequence_with_none(self):
  # Force graph mode.
  with tf.compat.v1.Graph().as_default():
    input_tensor = tf.constant(INPUT_TENSOR)
    input_tensor = tf.expand_dims(input_tensor, axis=0)
    batch = tf.tile(input_tensor, tf.constant([5, 1, 1]))
    logging.info(batch)
    output = utils.last_activations_in_sequence(batch)
    with self.test_session() as sess:
      output = sess.run(output)
    self.assertAllEqual(output,
                        np.array([[9, 9], [9, 9], [9, 9], [9, 9], [9, 9]]))
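For context, a sketch of the behavior the test exercises; this is an assumption based on the assertion, not the actual `utils.last_activations_in_sequence` implementation. With `lengths` omitted, the final timestep of every sequence is returned, so each of the five tiled rows yields `[9, 9]`:

# Assumed semantics of last_activations_in_sequence (sketch only).
import tensorflow as tf

def last_activations_sketch(activations, lengths=None):
  """activations: [batch, time, units]; lengths: int32 [batch] or None."""
  if lengths is None:
    return activations[:, -1, :]  # last timestep of every sequence
  batch = tf.shape(activations)[0]
  # Pick the activation at index lengths - 1 for each batch element.
  indices = tf.stack([tf.range(batch), lengths - 1], axis=1)
  return tf.gather_nd(activations, indices)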
def create_tower_spec(phoenix_spec,
                      inputs,
                      architecture,
                      dimension,
                      is_frozen,
                      lengths=None,
                      allow_auxiliary_head=False):
    """Creates the logits for the tower.

  Args:
    phoenix_spec: The trial's `phoenix_spec_pb2.PhoenixSpec` proto.
    inputs: The list of `tf.Tensors` of the tower.
    architecture: The list of `blocks.BlockType` of the tower architecture.
    dimension: int - the output tensor last axis dimension.
    is_frozen: Whether the tower should be frozen.
    lengths: A tensor of shape [batch] holding the sequence length for a
      sequential problem (rnn).
    allow_auxiliary_head: Whether to allow creating an auxiliary head if
      possible. Only applicable for CNNs.

  Returns:
    A LogitsSpec containing the main and auxiliary logits and the architecture
    of the underlying tower.
  """

    # inputs[0] holds the raw features; the last entry feeds the logits
    # layer. All layer tensors are kept so they can be returned in the
    # TowerSpec.
    all_layer_tensors = inputs
    pre_logits = inputs[-1]
    logits_weight = 1.0
    aux_logits = None
    aux_logits_weight = None
    if (phoenix_spec.problem_type ==
            phoenix_spec_pb2.PhoenixSpec.RNN_ALL_ACTIVATIONS):
        logits = tf.compat.v1.layers.conv1d(inputs=pre_logits,
                                            filters=dimension,
                                            kernel_size=1)
    elif (phoenix_spec.problem_type ==
          phoenix_spec_pb2.PhoenixSpec.RNN_LAST_ACTIVATIONS):
        if lengths is not None:
            logits = utils.last_activations_in_sequence(
                tf.compat.v1.layers.conv1d(inputs=pre_logits,
                                           filters=dimension,
                                           kernel_size=1), lengths)
        else:
            logging.warning("Length is missing for rnn_last problem type.")
            logits = tf.compat.v1.layers.conv1d(inputs=pre_logits,
                                                filters=dimension,
                                                kernel_size=1)
    elif phoenix_spec.problem_type == phoenix_spec_pb2.PhoenixSpec.CNN:
        logits = tf.keras.layers.Dense(dimension, name="dense")(pre_logits)
        if allow_auxiliary_head and phoenix_spec.use_auxiliary_head:
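            # An auxiliary head (as in Inception/NASNet) attaches a small
            # classifier to an intermediate layer so that earlier blocks
            # receive a direct gradient signal during training.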
            reductions = []
            flattens = []
            for i, block in enumerate(architecture):
                name = blocks.BlockType(block).name
                if "DOWNSAMPLE" in name or "REDUCTION" in name:
                    reductions.append(i)
                # Some blocks reduce and flatten.
                if "FLATTEN" in name:
                    flattens.append(i)
            if reductions:
                # Add the auxiliary head right before the reduction cell.
                idx = reductions[-1]
                aux_logits = _build_nas_aux_head(inputs[idx], dimension,
                                                 phoenix_spec.cnn_data_format)
                if aux_logits is not None:
                    aux_logits_weight = phoenix_spec.auxiliary_head_loss_weight
            if flattens and aux_logits is None:
                idx = flattens[-1]
                aux_logits = tf.keras.layers.Dense(
                    dimension, name="aux_dense")(inputs[idx])
                aux_logits_weight = phoenix_spec.auxiliary_head_loss_weight
    elif phoenix_spec.problem_type == phoenix_spec_pb2.PhoenixSpec.DNN:
        logits = tf.keras.layers.Dense(dimension, name="dense")(pre_logits)
    else:
        raise ValueError("phoenix_spec.problem_type must be either DNN, CNN, "
                         "RNN_LAST_ACTIVATIONS, or RNN_ALL_ACTIVATIONS.")

    logits = tf.identity(logits, name="logits")
    if aux_logits is not None:
        aux_logits = tf.identity(aux_logits, name="aux_logits")

    # TODO(b/172564129): Remove from eval graph.
    if is_frozen:
        logits = tf.stop_gradient(logits)
        if aux_logits is not None:
            aux_logits = tf.stop_gradient(aux_logits)

    return TowerSpec(
        logits_spec=LogitsSpec(logits, logits_weight, aux_logits,
                               aux_logits_weight),
        architecture=[blocks.BlockType(block).name for block in architecture],
        layer_tensors=all_layer_tensors)
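A hedged end-to-end sketch of calling `create_tower_spec` for a DNN problem; the block value and import paths are assumptions (check `blocks.BlockType` for the real enum members):

# Hypothetical usage sketch, assuming the model_search repo layout.
import tensorflow as tf
from model_search import blocks
from model_search.proto import phoenix_spec_pb2

spec = phoenix_spec_pb2.PhoenixSpec(
    problem_type=phoenix_spec_pb2.PhoenixSpec.DNN)
with tf.compat.v1.Graph().as_default():  # force graph mode, as in the test
  features = tf.compat.v1.placeholder(tf.float32, shape=[None, 64])
  hidden = tf.keras.layers.Dense(32, activation="relu")(features)
  tower_spec = create_tower_spec(
      phoenix_spec=spec,
      inputs=[features, hidden],  # inputs[0] is the raw features
      architecture=[blocks.BlockType.FULLY_CONNECTED],  # illustrative value
      dimension=10,
      is_frozen=False)
  logits = tower_spec.logits_spec.logits  # shape [batch, 10]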