Example #1
def process_episode(example_strings,
                    class_ids,
                    chunk_sizes,
                    image_size,
                    support_decoder=None,
                    query_decoder=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query" chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query" chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of ints representing the sizes of the flush and
      additional chunks.
    image_size: int, desired image size used during decoding.
    support_decoder: Decoder class instance for support set.
    query_decoder: Decoder class instance for query set.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  # TODO(goroshin): Replace with `support_decoder.log_summary(name='support')`.
  # TODO(goroshin): Eventually remove setting the image size here and pass it
  # to the ImageDecoder constructor instead.
  if isinstance(support_decoder, decoder.ImageDecoder):
    log_data_augmentation(support_decoder.data_augmentation, 'support')
    support_decoder.image_size = image_size
  if isinstance(query_decoder, decoder.ImageDecoder):
    log_data_augmentation(query_decoder.data_augmentation, 'query')
    query_decoder.image_size = image_size

  (support_strings, support_class_ids), (query_strings, query_class_ids) = \
      flush_and_chunk_episode(example_strings, class_ids, chunk_sizes)

  support_images = tf.map_fn(
      support_decoder, support_strings, dtype=tf.float32, back_prop=False)
  query_images = tf.map_fn(
      query_decoder, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
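
The relabeling at the end relies on the second output of `tf.unique`, which maps each element to the index of its first occurrence. A minimal standalone sketch (the IDs below are made up):

import tensorflow as tf

class_ids = tf.constant([17, 42, 17, 93, 42])  # absolute class IDs
unique_ids, labels = tf.unique(class_ids)
# unique_ids == [17, 42, 93]; labels == [0, 1, 0, 2, 1], i.e. values in
# [0, num_ways) ordered by first appearance.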
Example #2
def compute_train_class_proportions(episode, shots, dataset_spec):
    """Computes the proportion of each class' examples in the support set.

  Args:
    episode: An EpisodeDataset.
    shots: A 1D Tensor whose length is the `way' of the episode that stores the
      shots for this episode.
    dataset_spec: A DatasetSpecification.

  Returns:
    class_props: A 1D Tensor whose length is the `way' of the episode, storing
      for each class the proportion of its examples that are in the support set.
  """
    # Get the total number of examples of each class in the dataset.
    num_dataset_classes = len(dataset_spec.images_per_class)
    num_images_per_class = [
        dataset_spec.get_total_images_per_class(class_id)
        for class_id in range(num_dataset_classes)
    ]

    # Get the (absolute) class IDs that appear in the episode.
    class_ids, _ = tf.unique(episode.train_class_ids)  # [?, ]

    # Make sure that class_ids are valid indices of num_images_per_class. This is
    # important since tf.gather will fail silently and return zeros otherwise.
    num_classes = tf.shape(num_images_per_class)[0]
    check_valid_inds_op = tf.assert_less(class_ids, num_classes)
    with tf.control_dependencies([check_valid_inds_op]):
        # Get the total number of examples of each class that is in the episode.
        num_images_per_class = tf.gather(num_images_per_class,
                                         class_ids)  # [?, ]

    # Get the proportions of examples of each class that appear in the train set.
    class_props = tf.truediv(shots, num_images_per_class)
    return class_props
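
For illustration, a self-contained sketch (with made-up class counts) of the gather-and-divide step at the heart of this function:

import tensorflow as tf

num_images_per_class = [100, 50, 200]     # per-class totals in the dataset
class_ids = tf.constant([2, 0])           # classes appearing in the episode
shots = tf.constant([10, 25])             # support examples per class
totals = tf.gather(num_images_per_class, class_ids)  # [200, 100]
class_props = tf.truediv(shots, totals)              # [0.05, 0.25]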
Example #3
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Obtain all possible slates given current docs in the candidate set.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)
    # [batch_size, num_of_slates]
    batch_size = next_states.get_shape().as_list()[0]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click,
                tf.shape(input=slates)[:1]), [batch_size, -1])

    # [batch_size, num_of_slates]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate,
        axis=2) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) +
                   score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    return reward + gamma * next_q_target_max * (
        1. - tf.cast(terminals, tf.float32))
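
The meshgrid enumeration above is easiest to see on tiny numbers. A standalone sketch (made-up sizes) of how all duplicate-free slates are generated:

import tensorflow as tf

slate_size, num_candidates = 2, 3
mesh_args = [list(range(num_candidates))] * slate_size
slates = tf.reshape(
    tf.stack(tf.meshgrid(*mesh_args), axis=-1), (-1, slate_size))
unique_mask = tf.map_fn(
    lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
    slates,
    dtype=tf.bool)
slates = tf.boolean_mask(tensor=slates, mask=unique_mask)
# 3 * 3 = 9 raw slates; the 3 with a repeated document are dropped, leaving
# the 6 ordered pairs of distinct documents.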
Example #4
def compute_episode_stats(episode):
    """Computes various episode stats: way, shots, and class IDs.

  Args:
    episode: An EpisodeDataset.

  Returns:
    way: An int constant tensor. The number of classes in the episode.
    shots: An int 1D tensor. The number of support examples per class.
    class_ids: An int 1D tensor. The (absolute) class IDs.
  """
    # The train labels of the next episode.
    train_labels = episode.train_labels
    # Compute way.
    episode_classes, _ = tf.unique(train_labels)
    way = tf.size(episode_classes)
    # Compute shots.
    class_ids = tf.reshape(tf.range(way), [way, 1])
    class_labels = tf.reshape(train_labels, [1, -1])
    is_equal = tf.equal(class_labels, class_ids)
    shots = tf.reduce_sum(tf.cast(is_equal, tf.int32), axis=1)
    # Compute class_ids.
    class_ids, _ = tf.unique(episode.train_class_ids)
    return way, shots, class_ids
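
The shot computation above is a broadcasting trick: comparing a [way, 1] column of label values against a [1, n] row of labels gives a [way, n] boolean matrix whose row sums are the per-class counts. A standalone sketch with made-up labels:

import tensorflow as tf

train_labels = tf.constant([0, 1, 1, 2, 1, 0])
way = 3
is_equal = tf.equal(tf.reshape(train_labels, [1, -1]),
                    tf.reshape(tf.range(way), [way, 1]))
shots = tf.reduce_sum(tf.cast(is_equal, tf.int32), axis=1)  # [2, 3, 1]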
Example #5
def proto_maml_fc_layer_init_fn(labels, embeddings, weights, biases,
                                prototype_multiplier):
    """Return a list of operations for reparameterized ProtoNet initialization."""

    # This is robust to classes missing from the training set, but assumes that
    # the last class is present.
    num_ways = tf.cast(
        tf.math.reduce_max(input_tensor=tf.unique(labels)[0]) + 1, tf.int32)

    # When there are no examples for a given class, we default its prototype to
    # zeros, per the implementation of `tf.math.unsorted_segment_mean`.
    prototypes = tf.math.unsorted_segment_mean(embeddings, labels, num_ways)

    # Scale the prototypes, which acts as a regularizer on the weights and biases.
    prototypes *= prototype_multiplier

    # logit = -<squared Euclidean distance to prototype>
    #       = -(x - p)^T.(x - p)
    #       = 2 x^T.p - p^T.p - x^T.x
    #       = x^T.w + b - x^T.x
    #         where w = 2p, b = -p^T.p; the x^T.x term is the same for every
    #         class, so it drops out of the softmax and is omitted below.
    output_weights = tf.transpose(a=2 * prototypes)
    output_biases = -tf.reduce_sum(input_tensor=prototypes * prototypes,
                                   axis=1)

    # We zero-pad to align with the original weights and biases.
    num_pad_cols = (tf.shape(input=weights)[1] -
                    tf.shape(input=output_weights)[1])
    output_weights = tf.pad(tensor=output_weights,
                            paddings=[[0, 0], [0, num_pad_cols]],
                            mode='CONSTANT',
                            constant_values=0)
    num_pad_rows = (tf.shape(input=biases)[0] -
                    tf.shape(input=output_biases)[0])
    output_biases = tf.pad(tensor=output_biases,
                           paddings=[[0, num_pad_rows]],
                           mode='CONSTANT',
                           constant_values=0)

    return [
        weights.assign(output_weights),
        biases.assign(output_biases),
    ]
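
A quick NumPy check (made-up numbers, not part of the library) that the linear logits x^T.w + b match the negative squared distances up to the per-example constant x^T.x:

import numpy as np

x = np.array([1.0, 2.0])                     # one embedding
protos = np.array([[0.5, 0.0], [1.0, 1.0]])  # two prototypes
dist_logits = -np.sum((x - protos) ** 2, axis=1)
w = 2 * protos.T                             # w = 2p
b = -np.sum(protos * protos, axis=1)         # b = -p^T.p
lin_logits = x @ w + b
assert np.allclose(dist_logits, lin_logits - x @ x)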
Example #6
def select_slate_optimal(slate_size, s_no_click, s, q):
    """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """

    num_candidates = s.shape.as_list()[0]

    # Obtain all possible slates given current docs in the candidate set.
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))

    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    slate_q_values = tf.gather(s * q, slates)
    slate_scores = tf.gather(s, slates)
    slate_normalizer = tf.reduce_sum(input_tensor=slate_scores,
                                     axis=1) + s_no_click

    slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
    slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
    max_q_slate_index = tf.argmax(input=slate_sum_q_values)
    return tf.gather(slates, max_q_slate_index, axis=0)
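
Hypothetical usage (the scores and q values below are invented): with three candidate documents and a slate of size 2, the function scores all six duplicate-free slates and returns the best one:

import tensorflow as tf

s_no_click = tf.constant(1.0)
s = tf.constant([0.2, 0.5, 0.3])  # click scores per document
q = tf.constant([1.0, 0.1, 0.8])  # predicted q values per document
best_slate = select_slate_optimal(2, s_no_click, s, q)  # a [2] tensor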
Example #7
def process_episode(example_strings,
                    class_ids,
                    chunk_sizes,
                    image_size,
                    support_data_augmentation=None,
                    query_data_augmentation=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query" chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query" chunks,
     and
  4) extracts and processes images out of the example strings.
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of 3 ints representing the sizes of (resp.) the flush,
      support, and query chunks.
    image_size: int, desired image size used during decoding.
    support_data_augmentation: A DataAugmentation object with parameters for
      perturbing the support set images.
    query_data_augmentation: A DataAugmentation object with parameters for
      perturbing the query set images.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  _log_data_augmentation(support_data_augmentation, 'support')
  _log_data_augmentation(query_data_augmentation, 'query')
  flush_chunk_size, support_chunk_size, _ = chunk_sizes
  support_start = flush_chunk_size
  query_start = support_start + support_chunk_size
  support_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=support_data_augmentation)
  query_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=query_data_augmentation)

  support_strings = example_strings[support_start:query_start]
  support_class_ids = class_ids[support_start:query_start]
  (support_strings,
   support_class_ids) = filter_dummy_examples(support_strings,
                                              support_class_ids)
  support_images = tf.map_fn(
      support_map_fn, support_strings, dtype=tf.float32, back_prop=False)

  query_strings = example_strings[query_start:]
  query_class_ids = class_ids[query_start:]
  (query_strings,
   query_class_ids) = filter_dummy_examples(query_strings, query_class_ids)
  query_images = tf.map_fn(
      query_map_fn, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
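
To make the chunk arithmetic concrete, a small illustration (sizes are made up) of how `chunk_sizes` partitions the episode batch before the dummy-example filtering:

chunk_sizes = (4, 10, 6)  # (flush, support, query)
flush_chunk_size, support_chunk_size, _ = chunk_sizes
support_start = flush_chunk_size                    # 4
query_start = support_start + support_chunk_size    # 14
# example_strings[4:14] -> support chunk, example_strings[14:] -> query chunk;
# the leading example_strings[:4] flush chunk is discarded.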