def get_support_set_softmax(self, logits, class_ids):
    """Softmax normalize over the support set.

    The normalization runs jointly over axis 1 of every image that shares a
    class id, independently for each index of the last axis:

        softmax(x) = exp(x) / sum(exp(x))

    The per-class maximum is subtracted first and the division is carried out
    in log space, which keeps `exp` from overflowing on large logits.

    Args:
      logits: [N_k, H*W, Q] dimensional tensor.
      class_ids: [N_k] tensor giving the support-set-id of each image.

    Returns:
      Softmax-ed logits over the support set, same shape as `logits`.
    """
    # Built once and shared by both segment reductions.
    num_segments = tf.reduce_max(class_ids) + 1

    # Per-class maximum, gathered back to one row per image.
    max_logit = tf.reduce_max(logits, axis=1, keepdims=True)
    max_logit = tf.math.unsorted_segment_max(max_logit, class_ids,
                                             num_segments)
    max_logit = tf.gather(max_logit, class_ids)
    logits_reduc = logits - max_logit

    # Per-class log-sum-exp of the shifted logits.
    exp_x = tf.exp(logits_reduc)
    sum_exp_x = tf.reduce_sum(exp_x, axis=1, keepdims=True)
    sum_exp_x = tf.math.unsorted_segment_sum(sum_exp_x, class_ids,
                                             num_segments)
    # `tf.math.log` is used instead of the `tf.log` alias, which is not
    # available in the TF2 top-level namespace.
    log_sum_exp_x = tf.math.log(sum_exp_x)
    log_sum_exp_x = tf.gather(log_sum_exp_x, class_ids)

    norm_logits = logits_reduc - log_sum_exp_x
    return tf.exp(norm_logits)
def randomly_crop_points(mesh_inputs,
                         view_indices_2d_inputs,
                         x_random_crop_size,
                         y_random_crop_size,
                         epsilon=1e-5):
    """Keeps only the points inside a window around a randomly picked point.

    Both dictionaries are updated in place and nothing is returned. When both
    crop sizes are None the inputs are left untouched.

    Args:
      mesh_inputs: A dictionary containing input mesh (point) tensors. Every
        entry is masked along its first dimension.
      view_indices_2d_inputs: A dictionary containing input point to view
        correspondence tensors. Every entry is masked along its second
        dimension.
      x_random_crop_size: Size of the random crop in x dimension. If None,
        random crop will not take place on x dimension.
      y_random_crop_size: Size of the random crop in y dimension. If None,
        random crop will not take place on y dimension.
      epsilon: Epsilon (a very small value) added as a margin to every
        threshold.
    """
    if x_random_crop_size is None and y_random_crop_size is None:
        return

    points = mesh_inputs[standard_fields.InputDataFields.point_positions]
    num_points = tf.shape(points)[0]

    # At least one dimension is cropped past this point, so a random center is
    # always required.
    center_index = tf.random.uniform([],
                                     minval=0,
                                     maxval=num_points,
                                     dtype=tf.int32)
    center_x = points[center_index, 0]
    center_y = points[center_index, 1]

    points_x = points[:, 0]
    points_y = points[:, 1]

    # Start from the full extent of the cloud; a requested crop narrows it.
    min_x = tf.reduce_min(points_x) - epsilon
    max_x = tf.reduce_max(points_x) + epsilon
    min_y = tf.reduce_min(points_y) - epsilon
    max_y = tf.reduce_max(points_y) + epsilon
    if x_random_crop_size is not None:
        half_x = x_random_crop_size / 2.0
        min_x = center_x - half_x - epsilon
        max_x = center_x + half_x + epsilon
    if y_random_crop_size is not None:
        half_y = y_random_crop_size / 2.0
        min_y = center_y - half_y - epsilon
        max_y = center_y + half_y + epsilon

    inside_x = tf.logical_and(tf.greater(points_x, min_x),
                              tf.less(points_x, max_x))
    inside_y = tf.logical_and(tf.greater(points_y, min_y),
                              tf.less(points_y, max_y))
    points_mask = tf.logical_and(inside_x, inside_y)

    for key in sorted(mesh_inputs):
        mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], points_mask)

    # The point axis is second here: bring it to the front, mask, restore.
    for key in sorted(view_indices_2d_inputs):
        points_first = tf.transpose(view_indices_2d_inputs[key], [1, 0, 2])
        kept = tf.boolean_mask(points_first, points_mask)
        view_indices_2d_inputs[key] = tf.transpose(kept, [1, 0, 2])
def _build_train_op(self):
    """Builds the training op for Rainbow.

    The per-transition Huber loss between the chosen-action Q-value and the
    (gradient-stopped) target is weighted by `1 / sqrt(priority + 1e-10)`,
    normalized by the largest such weight in the batch. The replay priorities
    of the sampled transitions are set to `sqrt(loss + 1e-10)`, and that
    update is a control dependency of the returned training op.

    Returns:
      A `(train_op, weighted_loss)` tuple: the op performing one step of
      training, and the per-transition weighted loss tensor.
    """
    replay_action_one_hot = tf.one_hot(self._replay.actions,
                                       self.num_actions,
                                       1.,
                                       0.,
                                       name='action_one_hot')
    # `axis` replaces the deprecated `reduction_indices` keyword.
    replay_chosen_q = tf.reduce_sum(self._replay_qs * replay_action_one_hot,
                                    axis=1,
                                    name='replay_chosen_q')
    target = tf.stop_gradient(self._build_target_q_op())
    loss = tf.losses.huber_loss(target,
                                replay_chosen_q,
                                reduction=tf.losses.Reduction.NONE)

    update_priorities_op = self._replay.tf_set_priority(
        self._replay.indices, tf.sqrt(loss + 1e-10))

    target_priorities = self._replay.tf_get_priority(self._replay.indices)
    target_priorities = tf.math.add(target_priorities, 1e-10)
    target_priorities = 1.0 / tf.sqrt(target_priorities)
    target_priorities /= tf.reduce_max(target_priorities)

    weighted_loss = target_priorities * loss

    with tf.control_dependencies([update_priorities_op]):
        return self.optimizer.minimize(
            tf.reduce_mean(weighted_loss)), weighted_loss
def test_estimated_entropy(self, assume_reparametrization):
    """Compares estimated Normal entropy and its gradient with analytic ones.

    Builds a batch of two Normal distributions with random `loc` and `scale`,
    then checks that both the entropy estimate and the gradient of
    `est_entropy_for_gradient` w.r.t. `scale` are within 5e-2 of the analytic
    values. When `assume_reparametrization` is false it additionally checks
    that the gradient taken through `est_entropy` itself has magnitude below
    5e-2.

    Args:
      assume_reparametrization: forwarded to `dist_utils.estimated_entropy`.
    """
    # Arguments are passed lazily to the logger throughout, matching the
    # style already used by the last logging call of this test.
    logging.info("assume_reparametrization=%s", assume_reparametrization)
    num_samples = 1000000
    seed_stream = tfp.distributions.SeedStream(seed=1,
                                               salt='test_estimated_entropy')
    batch_shape = (2, )
    loc = tf.random.normal(shape=batch_shape, seed=seed_stream())
    scale = tf.abs(tf.random.normal(shape=batch_shape, seed=seed_stream()))

    # Persistent because up to three gradients are taken from the same tape.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(scale)
        dist = tfp.distributions.Normal(loc=loc, scale=scale)
        analytic_entropy = dist.entropy()
        est_entropy, est_entropy_for_gradient = dist_utils.estimated_entropy(
            dist=dist,
            seed=seed_stream(),
            assume_reparametrization=assume_reparametrization,
            num_samples=num_samples)

    analytic_grad = tape.gradient(analytic_entropy, scale)
    est_grad = tape.gradient(est_entropy_for_gradient, scale)

    logging.info("scale=%s", scale)
    logging.info("analytic_entropy=%s", analytic_entropy)
    logging.info("estimated_entropy=%s", est_entropy)
    self.assertArrayAlmostEqual(analytic_entropy, est_entropy, 5e-2)

    logging.info("analytic_entropy_grad=%s", analytic_grad)
    logging.info("estimated_entropy_grad=%s", est_grad)
    self.assertArrayAlmostEqual(analytic_grad, est_grad, 5e-2)

    if not assume_reparametrization:
        est_grad_wrong = tape.gradient(est_entropy, scale)
        logging.info("estimated_entropy_grad_wrong=%s", est_grad_wrong)
        self.assertLess(tf.reduce_max(tf.abs(est_grad_wrong)), 5e-2)
def _build_train_op(self):
    """Builds the training op for Rainbow.

    The loss is the softmax cross-entropy between the (gradient-stopped)
    target distribution and the logits of the action taken, weighted by
    `1 / sqrt(priority + 1e-10)` normalized by the batch maximum. Replay
    priorities of the sampled transitions are set to `sqrt(loss + 1e-10)`
    as a control dependency of the training op.

    Returns:
      A `(train_op, weighted_loss)` tuple: the op performing one step of
      training, and the per-transition weighted loss tensor.
    """
    target_distribution = tf.stop_gradient(self._build_target_distribution())

    # Pair every batch row index with its action: [batch_size, 2].
    batch_size = tf.shape(self._replay_logits)[0]
    row_ids = tf.range(batch_size)[:, None]
    row_action_pairs = tf.concat([row_ids, self._replay.actions[:, None]], 1)

    # Logits of the selected action for each element of the batch.
    chosen_action_logits = tf.gather_nd(self._replay_logits, row_action_pairs)

    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=target_distribution, logits=chosen_action_logits)

    optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                       epsilon=self.optimizer_epsilon)

    sampled_indices = self._replay.indices
    new_priorities = tf.sqrt(loss + 1e-10)
    update_priorities_op = self._replay.tf_set_priority(sampled_indices,
                                                        new_priorities)

    priorities = self._replay.tf_get_priority(self._replay.indices)
    priorities = tf.math.add(priorities, 1e-10)
    loss_weights = 1.0 / tf.sqrt(priorities)
    loss_weights /= tf.reduce_max(loss_weights)

    weighted_loss = loss_weights * loss

    with tf.control_dependencies([update_priorities_op]):
        train_op = optimizer.minimize(tf.reduce_mean(weighted_loss))
        return train_op, weighted_loss
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

    This algorithm corresponds to the method "OT" in
    Ie et al. https://arxiv.org/abs/1905.12767.

    Every ordered slate without repeated documents is enumerated, its value
    is the score-weighted average of the per-document Q-values (with the
    no-click score added to the normalizer), and the best slate value is
    used as the bootstrap target.

    Args:
      reward: [batch_size] tensor, the immediate reward.
      gamma: float, discount factor with the usual RL meaning.
      next_actions: [batch_size, slate_size] tensor, the next slate. Only its
        static slate size is used.
      next_q_values: [batch_size, num_of_documents] tensor, the q values of
        the documents in the next step.
      next_states: [batch_size, 1 + num_of_documents] tensor, the features
        for the user and the documents in the next step.
      terminals: [batch_size] tensor, indicating if this is a terminal step.

    Returns:
      [batch_size] tensor, the target q values.
    """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Obtain all possible slates given current docs in the candidate set.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)

    # [batch_size, 1], broadcast against [batch_size, num_of_slates] below.
    # Tiling the [batch_size] vector and reshaping it to [batch_size, -1]
    # would interleave the scores of different examples within each row, and
    # would also need a statically known batch size.
    score_no_click_slate = tf.expand_dims(score_no_click, axis=1)

    # [batch_size, num_of_slates]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate, axis=2) / (
            tf.reduce_sum(input_tensor=scores_slate, axis=2) +
            score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate,
                                      axis=1)

    # No bootstrapping from terminal transitions.
    return reward + gamma * next_q_target_max * (
        1. - tf.cast(terminals, tf.float32))
def _variable_summaries(var):
    """Records mean, stddev, max, min and a histogram of `var` as summaries.

    All summaries are created under the 'summaries' name scope (for
    TensorBoard visualization).
    """
    with tf.name_scope('summaries'):
        average = tf.reduce_mean(var)
        tf.summary.scalar('mean', average)

        with tf.name_scope('stddev'):
            squared_deviation = tf.square(var - average)
            spread = tf.sqrt(tf.reduce_mean(squared_deviation))
        tf.summary.scalar('stddev', spread)

        largest = tf.reduce_max(var)
        tf.summary.scalar('max', largest)
        smallest = tf.reduce_min(var)
        tf.summary.scalar('min', smallest)

        tf.summary.histogram('histogram', var)
def _build_train_op(self):
    """Builds a training op.

    The reward is augmented with an action-gap term computed from the target
    network on the current state:

        r' = r - alpha * (max_a Q_target(s, a) - Q_target(s, a_taken))

    and the Huber loss is taken between the online Q-value of the action
    taken and `r' + cumulative_gamma * max_a Q_target(s', a)`, with the
    bootstrap term dropped on terminal transitions.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    replay_next_target_value = tf.reduce_max(
        self._replay_next_target_net_outputs.q_values, 1)
    replay_target_value = tf.reduce_max(
        self._replay_target_net_outputs.q_values, 1)

    replay_action_one_hot = tf.one_hot(self._replay.actions,
                                       self.num_actions,
                                       1.,
                                       0.,
                                       name='action_one_hot')
    replay_chosen_q = tf.reduce_sum(self._replay_net_outputs.q_values *
                                    replay_action_one_hot,
                                    axis=1,
                                    name='replay_chosen_q')
    # Given its own op name; it previously reused 'replay_chosen_q', which
    # TensorFlow silently de-duplicated with a numeric suffix.
    replay_target_chosen_q = tf.reduce_sum(
        self._replay_target_net_outputs.q_values * replay_action_one_hot,
        axis=1,
        name='replay_target_chosen_q')

    augmented_rewards = self._replay.rewards - self.alpha * (
        replay_target_value - replay_target_chosen_q)

    target = (augmented_rewards +
              self.cumulative_gamma * replay_next_target_value *
              (1. - tf.cast(self._replay.terminals, tf.float32)))
    target = tf.stop_gradient(target)

    loss = tf.losses.huber_loss(target,
                                replay_chosen_q,
                                reduction=tf.losses.Reduction.NONE)
    if self.summary_writer is not None:
        with tf.variable_scope('Losses'):
            tf.summary.scalar('HuberLoss', tf.reduce_mean(loss))
    return self.optimizer.minimize(tf.reduce_mean(loss))
def summarize_stats(stats):
    """Summarize a dictionary of variables.

    For every entry a scalar summary is written for the mean, max, min and
    standard deviation, plus a histogram summary of the raw tensor.

    Args:
      stats: a dictionary of {name: tensor} to compute stats over.
    """
    for name, stat in stats.items():
        mean = tf.reduce_mean(stat)
        tf.summary.scalar('mean_%s' % name, mean)
        tf.summary.scalar('max_%s' % name, tf.reduce_max(stat))
        tf.summary.scalar('min_%s' % name, tf.reduce_min(stat))
        # Centered form: E[(x - mean)^2] is never negative, whereas
        # E[x^2] - mean^2 can dip below zero through floating-point
        # cancellation and turn the square root into NaN.
        variance = tf.reduce_mean(tf.square(stat - mean))
        std = tf.sqrt(variance + 1e-10)
        tf.summary.scalar('std_%s' % name, std)
        tf.summary.histogram(name, stat)
def _get_dist(self, query_queries, query_values, support_keys, support_values,
              labels):
    """Get distances between queries and query-aligned prototypes.

    Attended support values are summed per class label to form one prototype
    per class, and the squared difference to the query values is summed over
    axes 2, 3 and 4, then divided by the product of the third-to-last and
    second-to-last dimensions of `query_values`.
    """
    # attended_values: [N_support, n_query, h_query, w_query, C]
    attended_values = self._attend(query_queries, support_keys, support_values,
                                   labels)

    # query_aligned_prototypes: [N_classes, n_query, h_query, w_query, C]
    num_classes = tf.reduce_max(labels) + 1
    query_aligned_prototypes = tf.math.unsorted_segment_sum(
        attended_values, labels, num_classes)

    # (scaled) squared Euclidean distance
    query_shape = tf.shape(query_values)
    residual = query_values[tf.newaxis, ...] - query_aligned_prototypes
    aligned_dist = tf.square(residual)
    summed_dist = tf.reduce_sum(aligned_dist, [2, 3, 4])
    scale = tf.cast(query_shape[-3] * query_shape[-2], aligned_dist.dtype)
    return summed_dist / scale
def _build_target_q_op(self):
    """Build an op to be used as a target for the Q-value.

    Returns:
      target_q_op: An op calculating the target Q-value.
    """
    # `next_legal_actions` is added to the Q-values before the max is taken
    # across the actions dimension.
    # NOTE(review): presumably it is 0 for legal and a large negative value
    # for illegal actions -- confirm against the replay buffer.
    masked_next_q = self._replay_next_qt + self._replay.next_legal_actions
    replay_next_qt_max = tf.reduce_max(masked_next_q, 1)

    # Sample Bellman update:
    #   Q_t = R_t + gamma^N * Q'_t+1
    # where Q'_t+1 is max_a Q(S_t+1, a), or 0 on a terminal transition, and
    # N is the update horizon (by default, N=1).
    discounted_next_q = self.cumulative_gamma * replay_next_qt_max
    not_terminal = 1. - tf.cast(self._replay.terminals, tf.float32)
    return self._replay.rewards + discounted_next_q * not_terminal
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                       is_balanced, mine_hard_negatives,
                                       hard_negative_score_threshold):
    """Loss function for input and outputs of batch size 1.

    Logits and labels are restricted to valid voxels. With
    `mine_hard_negatives`, voxels labeled 0 whose class-0 softmax score is
    below `hard_negative_score_threshold` are appended a second time, under a
    fresh instance id. With `is_balanced`, per-voxel weights derived from the
    instance ids are passed to the loss.

    Raises:
      ValueError: if the number of classes is not statically known.
    """
    fields_in = standard_fields.InputDataFields
    fields_out = standard_fields.DetectionResultFields

    valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    logits_key = (fields_out.intermediate_object_semantic_voxels
                  if is_intermediate else fields_out.object_semantic_voxels)
    logits = outputs_1[logits_key]

    num_classes = logits.get_shape().as_list()[-1]
    if num_classes is None:
        raise ValueError('Number of classes is unknown.')

    flat_logits = tf.reshape(logits, [-1, num_classes])
    logits = tf.boolean_mask(flat_logits, valid_mask)
    flat_labels = tf.reshape(inputs_1[fields_in.object_class_voxels], [-1, 1])
    labels = tf.boolean_mask(flat_labels, valid_mask)

    # Instance ids are needed by both optional features below.
    if mine_hard_negatives or is_balanced:
        flat_instances = tf.reshape(
            inputs_1[fields_in.object_instance_id_voxels], [-1])
        instances = tf.boolean_mask(flat_instances, valid_mask)

    params = {}
    if mine_hard_negatives:
        background_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
        low_background_score = tf.less(background_scores,
                                       hard_negative_score_threshold)
        is_background = tf.equal(tf.reshape(labels, [-1]), 0)
        hard_negative_mask = tf.logical_and(low_background_score,
                                            is_background)

        extra_labels = tf.boolean_mask(labels, hard_negative_mask)
        extra_logits = tf.boolean_mask(logits, hard_negative_mask)
        # All hard negatives share one instance id above every existing id.
        new_instance_ids = tf.ones_like(instances) * (
            tf.reduce_max(instances) + 1)
        extra_instances = tf.boolean_mask(new_instance_ids,
                                          hard_negative_mask)

        logits = tf.concat([logits, extra_logits], axis=0)
        instances = tf.concat([instances, extra_instances], axis=0)
        labels = tf.concat([labels, extra_labels], axis=0)

    if is_balanced:
        params['weights'] = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(instances, axis=1))

    return classification_loss_fn(logits=logits, labels=labels, **params)
def fwd_fn(query_queries_fwd, query_values_fwd, support_keys_fwd,
           support_values_fwd, labels_fwd):
    """CrossTransformer forward, using a while loop to save memory.

    Queries are processed one at a time with `parallel_iterations=1`; the
    distances returned by `self._get_dist` for each query are concatenated
    along axis 1 onto an accumulator that starts with `zero_dim` columns.

    Returns:
      The accumulated distance tensor, with one row per class label.
    """
    # The accumulator's row count must match what `_get_dist` produces from
    # `labels_fwd`, so it is sized from this function's own argument rather
    # than from a same-named variable in the enclosing scope.
    initial = (0,
               tf.zeros([tf.reduce_max(labels_fwd) + 1, zero_dim],
                        dtype=query_queries_fwd.dtype))

    def loop_body(idx, dist):
        dist_new = self._get_dist(query_queries_fwd[idx:idx + 1],
                                  query_values_fwd[idx:idx + 1],
                                  support_keys_fwd, support_values_fwd,
                                  labels_fwd)
        dist = tf.concat([dist, dist_new], axis=1)
        return (idx + 1, dist)

    _, res = tf.while_loop(
        lambda x, _: x < tf.shape(query_queries_fwd)[0],
        loop_body,
        initial,
        parallel_iterations=1)
    return res
def assertArrayAlmostEqual(self, x, y, eps):
    """Asserts that the largest absolute difference of x and y is below eps."""
    max_abs_diff = tf.reduce_max(tf.abs(x - y))
    self.assertLess(max_abs_diff, eps)
def classification_loss_using_mask_iou_func(embeddings,
                                            logits,
                                            instance_ids,
                                            class_labels,
                                            num_samples,
                                            valid_mask=None,
                                            max_instance_id=None,
                                            similarity_strategy='dotproduct',
                                            is_balanced=True):
    """Classification loss using mask iou.

    Samples are drawn for the whole batch with
    `sampling_utils.balanced_sample`; the unbatched loss is then evaluated
    per example and the per-example losses are averaged.

    Args:
      embeddings: A tf.float32 tensor of size [batch_size, n, f].
      logits: A tf.float32 tensor of size [batch_size, n, num_classes]. It is
        assumed that background is class 0.
      instance_ids: A tf.int32 tensor of size [batch_size, n].
      class_labels: A tf.int32 tensor of size [batch_size, n]. It is assumed
        that the background voxels are assigned to class 0.
      num_samples: An int determining the number of samples.
      valid_mask: A tf.bool tensor of size [batch_size, n] that is True when
        an element is valid and False if it needs to be ignored. By default
        the value is None which means it is not applied.
      max_instance_id: If set, instance ids larger than that value will be
        ignored. If not set, it will be computed from instance_ids tensor.
      similarity_strategy: Defines the method for computing similarity between
        embedding vectors. Possible values are 'dotproduct' and 'distance'.
      is_balanced: If True, the per-voxel losses are re-weighted to have equal
        total weight for foreground vs. background voxels.

    Returns:
      A tf.float32 scalar loss tensor.

    Raises:
      ValueError: if the batch size is not statically known.
    """
    batch_size = embeddings.get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError('Unknown batch size at graph construction time.')
    if max_instance_id is None:
        max_instance_id = tf.reduce_max(instance_ids)

    class_labels = tf.reshape(class_labels, [batch_size, -1, 1])
    sampled = sampling_utils.balanced_sample(features=embeddings,
                                             instance_ids=instance_ids,
                                             num_samples=num_samples,
                                             valid_mask=valid_mask,
                                             max_instance_id=max_instance_id)
    sampled_embeddings, sampled_instance_ids, sampled_indices = sampled

    per_example_losses = []
    for b in range(batch_size):
        example_embeddings = embeddings[b, :, :]
        example_instance_ids = instance_ids[b, :]
        example_class_labels = class_labels[b, :, :]
        example_logits = logits[b, :]
        picked_embeddings = sampled_embeddings[b, :, :]
        picked_instance_ids = sampled_instance_ids[b, :]
        picked_indices = sampled_indices[b, :]

        # Labels and logits of the sampled elements are looked up before any
        # masking is applied.
        picked_class_labels = tf.gather(example_class_labels, picked_indices)
        picked_logits = tf.gather(example_logits, picked_indices)

        if valid_mask is not None:
            example_valid = valid_mask[b]
            example_embeddings = tf.boolean_mask(example_embeddings,
                                                 example_valid)
            example_instance_ids = tf.boolean_mask(example_instance_ids,
                                                   example_valid)

        per_example_losses.append(
            classification_loss_using_mask_iou_func_unbatched(
                embeddings=example_embeddings,
                instance_ids=example_instance_ids,
                sampled_embeddings=picked_embeddings,
                sampled_instance_ids=picked_instance_ids,
                sampled_class_labels=picked_class_labels,
                sampled_logits=picked_logits,
                similarity_strategy=similarity_strategy,
                is_balanced=is_balanced))

    return tf.math.reduce_mean(tf.stack(per_example_losses))
def train(hparams, num_epoch, tuning):
    """Trains a ResNet, validates every epoch, then evaluates on the test set.

    Args:
      hparams: mapping read for 'HP_BS' (passed to `make_dataset` as
        `BATCH_SIZE` for the training data) and 'HP_LR' (learning rate of the
        Adam optimizer).
      num_epoch: number of passes over the training set.
      tuning: when truthy, checkpoint restore/save and the final activation
        map visualization are skipped.

    Returns:
      The result of the test-set accuracy metric.
    """
    log_dir = './results/'
    test_batch_size = 8
    # Load dataset
    training_set, valid_set = make_dataset(BATCH_SIZE=hparams['HP_BS'], file_name='train_tf_record', split=True)
    test_set = make_dataset(BATCH_SIZE=test_batch_size, file_name='test_tf_record', split=False)
    class_names = ['NRDR', 'RDR']
    # Model
    model = ResNet()
    # set optimizer
    optimizer = tf.keras.optimizers.Adam(learning_rate=hparams['HP_LR'])
    # set metrics
    # Training accuracy is fed raw model outputs; validation/test accuracy is
    # fed argmax-ed class ids.
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    valid_accuracy = tf.keras.metrics.Accuracy()
    valid_con_mat = ConfusionMatrix(num_class=2)
    test_accuracy = tf.keras.metrics.Accuracy()
    test_con_mat = ConfusionMatrix(num_class=2)
    # Save Checkpoint
    if not tuning:
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer, net=model)
        manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=5)
    # Set up summary writers
    # Each run logs to its own timestamped directory under `log_dir`.
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_log_dir = log_dir + current_time + '/train'
    summary_writer = tf.summary.create_file_writer(tb_log_dir)
    # Restore Checkpoint
    if not tuning:
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            logging.info('Restored from {}'.format(manager.latest_checkpoint))
        else:
            logging.info('Initializing from scratch.')

    @tf.function
    def train_step(train_img, train_label):
        # Optimize the model
        loss_value, grads = grad(model, train_img, train_label)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        # Second forward pass, after the weight update, feeds the accuracy
        # metric.
        train_pred, _ = model(train_img)
        train_label = tf.expand_dims(train_label, axis=1)
        train_accuracy.update_state(train_label, train_pred)

    for epoch in range(num_epoch):
        begin = time()
        # Training loop
        for train_img, train_label, train_name in training_set:
            train_img = data_augmentation(train_img)
            train_step(train_img, train_label)
        with summary_writer.as_default():
            tf.summary.scalar('Train Accuracy', train_accuracy.result(), step=epoch)
        # Validation pass: images are cast to float32 and scaled by 1/255.
        for valid_img, valid_label, _ in valid_set:
            valid_img = tf.cast(valid_img, tf.float32)
            valid_img = valid_img / 255.0
            valid_pred, _ = model(valid_img, training=False)
            valid_pred = tf.cast(tf.argmax(valid_pred, axis=1), dtype=tf.int64)
            valid_con_mat.update_state(valid_label, valid_pred)
            valid_accuracy.update_state(valid_label, valid_pred)
        # Log the confusion matrix as an image summary
        cm_valid = valid_con_mat.result()
        figure = plot_confusion_matrix(cm_valid, class_names=class_names)
        cm_valid_image = plot_to_image(figure)
        with summary_writer.as_default():
            tf.summary.scalar('Valid Accuracy', valid_accuracy.result(), step=epoch)
            tf.summary.image('Valid ConfusionMatrix', cm_valid_image, step=epoch)
        end = time()
        logging.info(
            "Epoch {:d} Training Accuracy: {:.3%} Validation Accuracy: {:.3%} Time:{:.5}s"
            .format(epoch + 1, train_accuracy.result(), valid_accuracy.result(), (end - begin)))
        # Metrics are per-epoch, so clear them before the next pass.
        train_accuracy.reset_states()
        valid_accuracy.reset_states()
        valid_con_mat.reset_states()
        if not tuning:
            # A checkpoint is written whenever the step counter is a multiple
            # of 5; the counter advances once per epoch.
            if int(ckpt.step) % 5 == 0:
                save_path = manager.save()
                logging.info('Saved checkpoint for epoch {}: {}'.format(int(ckpt.step), save_path))
            ckpt.step.assign_add(1)

    # Final evaluation on the test set, with the same preprocessing as
    # validation.
    for test_img, test_label, _ in test_set:
        test_img = tf.cast(test_img, tf.float32)
        test_img = test_img / 255.0
        test_pred, _ = model(test_img, training=False)
        test_pred = tf.cast(tf.argmax(test_pred, axis=1), dtype=tf.int64)
        test_accuracy.update_state(test_label, test_pred)
        test_con_mat.update_state(test_label, test_pred)
    cm_test = test_con_mat.result()
    # Log the confusion matrix as an image summary
    figure = plot_confusion_matrix(cm_test, class_names=class_names)
    cm_test_image = plot_to_image(figure)
    # NOTE(review): `epoch` is the last value of the training loop variable;
    # it is unbound here when `num_epoch` is 0 -- confirm callers never pass 0.
    with summary_writer.as_default():
        tf.summary.scalar('Test Accuracy', test_accuracy.result(), step=epoch)
        tf.summary.image('Test ConfusionMatrix', cm_test_image, step=epoch)
    logging.info("Trained finished. Final Accuracy in test set: {:.3%}".format(test_accuracy.result()))
    # Visualization
    if not tuning:
        # Only the first image of the first test batch is visualized (note
        # the `break` at the end of the loop body).
        for vis_img, vis_label, vis_name in test_set:
            vis_label = vis_label[0]
            vis_name = vis_name[0]
            vis_img = tf.cast(vis_img[0], tf.float32)
            vis_img = tf.expand_dims(vis_img, axis=0)
            vis_img = vis_img / 255.0
            with tf.GradientTape() as tape:
                vis_pred, conv_output = model(vis_img, training=False)
                pred_label = tf.argmax(vis_pred, axis=-1)
                vis_pred = tf.reduce_max(vis_pred, axis=-1)
            # Gradient of the top score w.r.t. the model's second output,
            # averaged over axes 1 and 2 to weight that output's channels.
            grad_1 = tape.gradient(vis_pred, conv_output)
            weight = tf.reduce_mean(grad_1, axis=[1, 2]) / grad_1.shape[1]
            act_map0 = tf.nn.relu(tf.reduce_sum(weight * conv_output, axis=-1))
            # Resize the activation map to 256x256 for plotting.
            act_map0 = tf.squeeze(tf.image.resize(tf.expand_dims(act_map0, axis=-1), (256, 256), antialias=True), axis=-1)
            plot_map(vis_img, act_map0, vis_pred, pred_label, vis_label, vis_name)
            break
    return test_accuracy.result()
def npair_loss_func(embeddings,
                    instance_ids,
                    num_samples,
                    valid_mask=None,
                    max_instance_id=None,
                    similarity_strategy='dotproduct',
                    loss_strategy='softmax'):
    """N-pair metric learning loss for learning feature embeddings.

    Samples are drawn with `sampling_utils.balanced_sample`. For every
    example the n-pair loss is evaluated on the sampled embeddings, unless
    all sampled instance ids are equal, in which case that example
    contributes 0. The per-example losses are averaged.

    Args:
      embeddings: A tf.float32 tensor of size [batch_size, n, f].
      instance_ids: A tf.int32 tensor of size [batch_size, n].
      num_samples: An int determining the number of samples.
      valid_mask: A tf.bool tensor of size [batch_size, n] that is True when
        an element is valid and False if it needs to be ignored. By default
        the value is None which means it is not applied.
      max_instance_id: If set, instance ids larger than that value will be
        ignored. If not set, it will be computed from instance_ids tensor.
      similarity_strategy: Defines the method for computing similarity between
        embedding vectors. Possible values are 'dotproduct' and 'distance'.
      loss_strategy: Defines the type of loss including 'softmax' or 'sigmoid'.

    Returns:
      A tf.float32 scalar loss tensor.

    Raises:
      ValueError: if the batch size is not statically known.
    """
    batch_size = embeddings.get_shape().as_list()[0]
    if batch_size is None:
        raise ValueError('Unknown batch size at graph construction time.')
    if max_instance_id is None:
        max_instance_id = tf.reduce_max(instance_ids)

    sampled = sampling_utils.balanced_sample(features=embeddings,
                                             instance_ids=instance_ids,
                                             num_samples=num_samples,
                                             valid_mask=valid_mask,
                                             max_instance_id=max_instance_id)
    sampled_embeddings, sampled_instance_ids, _ = sampled

    per_example_losses = []
    for b in range(batch_size):
        picked_ids = sampled_instance_ids[b, :]
        picked_embeddings = sampled_embeddings[b, :, :]
        lowest_id = tf.math.reduce_min(picked_ids)
        highest_id = tf.math.reduce_max(picked_ids)
        one_hot_target = tf.one_hot(picked_ids,
                                    depth=(max_instance_id + 1),
                                    dtype=tf.float32)

        # The closure is invoked by `tf.cond` within this same iteration, so
        # capturing the loop variables is safe.
        # pylint: disable=cell-var-from-loop
        def example_npair_loss():
            return metric_learning_losses.npair_loss(
                embedding=picked_embeddings,
                target=one_hot_target,
                similarity_strategy=similarity_strategy,
                loss_strategy=loss_strategy)
        # pylint: enable=cell-var-from-loop

        # More than one distinct id is needed for the loss to be evaluated.
        has_distinct_ids = highest_id > lowest_id
        example_loss = tf.cond(has_distinct_ids, example_npair_loss,
                               lambda: tf.constant(0.0, dtype=tf.float32))
        per_example_losses.append(example_loss)

    return tf.math.reduce_mean(per_example_losses)