def process_episode(example_strings, class_ids, chunk_sizes, image_size,
                    support_decoder=None, query_decoder=None):
  """Turns a flat batch of serialized examples into a support/query episode.

  The work happens in three stages:
    1) `flush_and_chunk_episode` separates the batch into its "support" and
       "query" parts (by this function's contract the "flush" chunk and the
       padded dummy examples are dropped at that point),
    2) every remaining example string is run through the decoder of its set,
    3) the class IDs of each set are mapped to episode-level labels (numbers
       from 0 to K-1 where K is the number of classes in the episode).

  Labels are the index output of `tf.unique`, i.e. each class ID is replaced
  by the rank of its first occurrence. The mapping is computed separately for
  the support set and for the query set.

  Decoders that are `decoder.ImageDecoder` instances are modified in place:
  their `image_size` attribute is overwritten with `image_size`, and their
  data augmentation settings are logged.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol
      buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of ints representing the sizes the flush and additional
      chunks.
    image_size: int, desired image size used during decoding.
    support_decoder: Decoder class instance for support set. It is called on
      each support example string through `tf.map_fn`.
    query_decoder: Decoder class instance for query set. It is called on each
      query example string through `tf.map_fn`.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  # TODO(goroshin): Replace with `support_decoder.log_summary(name='support')`.
  # TODO(goroshin): Eventually remove setting the image size here and pass it
  # to the ImageDecoder constructor instead.
  for split_decoder, split_name in ((support_decoder, 'support'),
                                    (query_decoder, 'query')):
    # Only image decoders carry augmentation settings and an image size.
    if not isinstance(split_decoder, decoder.ImageDecoder):
      continue
    log_data_augmentation(split_decoder.data_augmentation, split_name)
    split_decoder.image_size = image_size

  support_chunk, query_chunk = flush_and_chunk_episode(
      example_strings, class_ids, chunk_sizes)
  support_strings, support_class_ids = support_chunk
  query_strings, query_class_ids = query_chunk

  # Decoding is not differentiated through, hence back_prop=False.
  support_images = tf.map_fn(
      support_decoder, support_strings, dtype=tf.float32, back_prop=False)
  query_images = tf.map_fn(
      query_decoder, query_strings, dtype=tf.float32, back_prop=False)

  # Relabel the absolute class IDs as integers in [0, num_ways).
  support_labels = tf.unique(support_class_ids)[1]
  query_labels = tf.unique(query_class_ids)[1]

  return (support_images, support_labels, support_class_ids,
          query_images, query_labels, query_class_ids)
def compute_train_class_proportions(episode, shots, dataset_spec):
  """Returns, per episode class, the fraction of its images in the support set.

  Args:
    episode: An EpisodeDataset. Only its `train_class_ids` attribute is read.
    shots: A 1D Tensor whose length is the `way' of the episode that stores the
      shots for this episode.
    dataset_spec: A DatasetSpecification. Its `images_per_class` attribute
      gives the number of classes, and `get_total_images_per_class` is queried
      once for every class index.

  Returns:
    class_props: A 1D Tensor whose length is the `way' of the episode, storing
      for each class the proportion of its examples that are in the support
      set.
  """
  # Total number of examples of every class of the dataset, indexed by the
  # (absolute) class ID.
  total_classes = len(dataset_spec.images_per_class)
  images_per_dataset_class = list(
      map(dataset_spec.get_total_images_per_class, range(total_classes)))

  # The (absolute) class IDs present in the episode, in order of first
  # appearance.  [?, ]
  episode_class_ids = tf.unique(episode.train_class_ids)[0]

  # Guard the lookup below: every class ID has to be a valid index into
  # images_per_dataset_class, since tf.gather will fail silently and return
  # zeros otherwise.
  upper_bound = tf.shape(images_per_dataset_class)[0]
  ids_in_range = tf.assert_less(episode_class_ids, upper_bound)
  with tf.control_dependencies([ids_in_range]):
    # Total number of examples of each class that is in the episode.  [?, ]
    images_per_episode_class = tf.gather(images_per_dataset_class,
                                         episode_class_ids)

  # Share of each class' examples that ended up in the train (support) set.
  return tf.truediv(shots, images_per_episode_class)
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
  """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Every ordered slate of distinct candidate documents is enumerated; the
  target uses the largest expected next-step Q-value over those slates, where
  a slate's value is the score-weighted sum of its documents' Q-values,
  normalized by the slate's total score plus the no-click score.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate. Only its
      static second dimension (the slate size) is used.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for
      the user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  # score_no_click is assumed to be a [batch_size] vector holding one no-click
  # score per batch element, with scores being [batch_size, num_of_documents].
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Obtain all possible slates given current docs in the candidate set.
  slate_size = next_actions.get_shape().as_list()[1]
  num_candidates = next_q_values.get_shape().as_list()[1]
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))
  # Filter slates that include duplicates to ensure each document is picked
  # at most once.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  # [num_of_slates, slate_size]
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  # [batch_size, num_of_slates, slate_size]
  next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
  # [batch_size, num_of_slates, slate_size]
  scores_slate = tf.gather(scores, slates, axis=1)

  # [batch_size, 1]; broadcasts over the slate axis so that every slate of a
  # batch element is normalized with that element's own no-click score.
  # (Tiling the flat vector and reshaping it to [batch_size, num_of_slates]
  # would interleave the scores of different batch elements, because tf.tile
  # concatenates whole copies of its input.)
  score_no_click_slate = tf.expand_dims(score_no_click, axis=1)

  # [batch_size, num_of_slates]
  next_q_target_slate = tf.reduce_sum(
      input_tensor=next_q_values_slate * scores_slate, axis=2) / (
          tf.reduce_sum(input_tensor=scores_slate, axis=2) +
          score_no_click_slate)

  # [batch_size]
  next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

  # Terminal transitions contribute no bootstrapped value.
  return reward + gamma * next_q_target_max * (
      1. - tf.cast(terminals, tf.float32))
def compute_episode_stats(episode):
  """Derives the way, the shots and the class IDs of an episode.

  Args:
    episode: An EpisodeDataset. Its `train_labels` and `train_class_ids`
      attributes are read.

  Returns:
    way: An int constant tensor. The number of classes in the episode.
    shots: An int 1D tensor: The number of support examples per class.
    class_ids: An int 1D tensor: (absolute) class IDs.
  """
  support_labels = episode.train_labels

  # Way: the number of distinct labels in the support set.
  way = tf.size(tf.unique(support_labels)[0])

  # Shots: compare every candidate label (a column of shape [way, 1]) against
  # every support label (a row of shape [1, num_examples]) and count the
  # matches along each row.
  candidate_labels = tf.reshape(tf.range(way), [way, 1])
  labels_row = tf.reshape(support_labels, [1, -1])
  matches = tf.cast(tf.equal(labels_row, candidate_labels), tf.int32)
  shots = tf.reduce_sum(matches, axis=1)

  # Class IDs: the distinct (absolute) IDs, in order of first appearance.
  class_ids = tf.unique(episode.train_class_ids)[0]

  return way, shots, class_ids
def proto_maml_fc_layer_init_fn(labels, embeddings, weights, biases,
                                prototype_multiplier):
  """Return a list of operations for reparameterized ProtoNet initialization.

  The class prototypes (per-label means of `embeddings`, scaled by
  `prototype_multiplier`) are turned into the weights and biases of a linear
  layer whose logits equal, up to a per-example constant, the negative squared
  Euclidean distance to each prototype.

  Args:
    labels: Integer labels, one per row of `embeddings`.
    embeddings: Embeddings to average into prototypes.
    weights: Variable receiving the new weights. The computed weights are
      zero-padded on the right of their second axis to match its shape.
    biases: Variable receiving the new biases. The computed biases are
      zero-padded at the end to match its shape.
    prototype_multiplier: Scaling factor applied to the prototypes.

  Returns:
    A list with the assign op for `weights` followed by the one for `biases`.
  """
  # Taking the largest label plus one is robust to classes missing from the
  # training set, but assumes that the last class is present.
  largest_label = tf.math.reduce_max(input_tensor=tf.unique(labels)[0])
  num_ways = tf.cast(largest_label + 1, tf.int32)

  # A class without any example gets an all-zero prototype, per the
  # implementation of `tf.math.unsorted_segment_mean`.
  prototypes = tf.math.unsorted_segment_mean(embeddings, labels, num_ways)

  # Scaling the prototypes acts as a regularizer on the weights and biases.
  prototypes = prototypes * prototype_multiplier

  # logit = -<squared Euclidian distance to prototype>
  #       = -(x - p)^T.(x - p)
  #       = 2 x^T.p - p^T.p - x^T.x
  #       = x^T.w + b
  #         where w = 2p, b = -p^T.p
  # (the -x^T.x term does not depend on the class).
  output_weights = tf.transpose(a=2 * prototypes)
  output_biases = -tf.reduce_sum(input_tensor=prototypes * prototypes, axis=1)

  # Zero-pad so that the new values line up with the original variables.
  missing_columns = (
      tf.shape(input=weights)[1] - tf.shape(input=output_weights)[1])
  output_weights = tf.pad(
      tensor=output_weights,
      paddings=[[0, 0], [0, missing_columns]],
      mode='CONSTANT',
      constant_values=0)

  missing_entries = (
      tf.shape(input=biases)[0] - tf.shape(input=output_biases)[0])
  output_biases = tf.pad(
      tensor=output_biases,
      paddings=[[0, missing_entries]],
      mode='CONSTANT',
      constant_values=0)

  assign_weights = weights.assign(output_weights)
  assign_biases = biases.assign(output_biases)
  return [assign_weights, assign_biases]
def select_slate_optimal(slate_size, s_no_click, s, q):
  """Picks the best slate by scoring every admissible slate.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """

  def _has_distinct_docs(slate):
    # A slate is admissible when de-duplicating it does not shrink it.
    return tf.equal(
        tf.size(input=slate), tf.size(input=tf.unique(slate)[0]))

  num_docs = s.shape.as_list()[0]

  # Enumerate every ordered slate over the candidate documents: one grid axis
  # per slate position, flattened into rows of document indices.
  position_choices = [list(range(num_docs)) for _ in range(slate_size)]
  all_slates = tf.reshape(
      tf.stack(tf.meshgrid(*position_choices), axis=-1),
      shape=(-1, slate_size))

  # Drop slates that repeat a document, so each document is picked at most
  # once.
  keep = tf.map_fn(_has_distinct_docs, all_slates, dtype=tf.bool)
  slates = tf.boolean_mask(tensor=all_slates, mask=keep)

  # Per-document contributions and scores, laid out per slate.
  weighted_q = tf.gather(s * q, slates)
  doc_scores = tf.gather(s, slates)

  # Each slate is normalized by its total score plus the no-click score.
  normalizer = tf.reduce_sum(input_tensor=doc_scores, axis=1) + s_no_click
  normalized_q = weighted_q / tf.expand_dims(normalizer, 1)
  slate_values = tf.reduce_sum(input_tensor=normalized_q, axis=1)

  best_slate_index = tf.argmax(input=slate_values)
  return tf.gather(slates, best_slate_index, axis=0)
def process_episode(example_strings, class_ids, chunk_sizes, image_size,
                    support_data_augmentation=None,
                    query_data_augmentation=None):
  """Turns a flat batch of serialized examples into a support/query episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query" chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query" chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Targets are the index output of `tf.unique`, computed separately for the
  support set and for the query set.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol
      buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of 3 ints representing the sizes of (resp.) the flush,
      support, and query chunks. Only the first two are used: the query chunk
      is everything that follows the support chunk.
    image_size: int, desired image size used during decoding.
    support_data_augmentation: A DataAugmentation object with parameters for
      perturbing the support set images.
    query_data_augmentation: A DataAugmentation object with parameters for
      perturbing the query set images.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """

  def _decode(strings, data_augmentation):
    # Decodes every example string with the given augmentation settings.
    example_fn = functools.partial(
        process_example,
        image_size=image_size,
        data_augmentation=data_augmentation)
    return tf.map_fn(example_fn, strings, dtype=tf.float32, back_prop=False)

  _log_data_augmentation(support_data_augmentation, 'support')
  _log_data_augmentation(query_data_augmentation, 'query')

  # Layout of the batch: [flush | support | query].
  num_flush, num_support, _ = chunk_sizes
  support_span = slice(num_flush, num_flush + num_support)
  query_span = slice(num_flush + num_support, None)

  support_strings, support_class_ids = filter_dummy_examples(
      example_strings[support_span], class_ids[support_span])
  support_images = _decode(support_strings, support_data_augmentation)

  query_strings, query_class_ids = filter_dummy_examples(
      example_strings[query_span], class_ids[query_span])
  query_images = _decode(query_strings, query_data_augmentation)

  # Relabel the absolute class IDs as integers in [0, num_ways).
  support_labels = tf.unique(support_class_ids)[1]
  query_labels = tf.unique(query_class_ids)[1]

  return (support_images, support_labels, support_class_ids,
          query_images, query_labels, query_class_ids)