# NOTE: the snippets below assume module-level imports and constants that
# are not shown here: `tf` (TensorFlow), `xla` (typically
# tensorflow.compiler.tf2xla.python.xla, which provides replica_id()),
# `math_ops`, `FLAGS`, and constants such as LARGE_NUM (e.g. 1e9) and MARGIN.


def tpu_cross_replica_concat(tensor, tpu_context=None):
  """Reduce a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    tpu_context: A `TPUContext`. If not set, CPU execution is assumed.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  if tpu_context is None or tpu_context.num_replicas <= 1:
    return tensor

  num_replicas = tpu_context.num_replicas

  with tf.name_scope("tpu_cross_replica_concat"):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[xla.replica_id()]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list(),
    )

    # As every value is only present on one replica and 0 in all others,
    # adding them all together will result in the full tensor on all replicas.
    ext_tensor = tf.tpu.cross_replica_sum(ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
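
# A minimal self-contained sketch (NumPy stand-in, not part of the original
# code) of why scatter + all-reduce equals a concat: each replica scatters
# its local tensor into an otherwise-zero buffer, and the element-wise sum
# of all buffers is the concatenation.
def _concat_via_sum_demo():
  import numpy as np
  num_replicas = 4
  local = [np.full((2, 3), r, np.float32) for r in range(num_replicas)]
  buffers = []
  for r, t in enumerate(local):
    buf = np.zeros((num_replicas, 2, 3), np.float32)
    buf[r] = t  # what tf.scatter_nd(indices=[[replica_id]], ...) does
    buffers.append(buf)
  summed = np.sum(buffers, axis=0)  # what cross_replica_sum does
  concat = summed.reshape(-1, 3)    # flatten the replica dimension
  assert np.array_equal(concat, np.concatenate(local, axis=0))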
def cross_replica_concat(tensor, num_replicas, name=None):
  """Reduce a concatenation of the `tensor` across TPU cores.

  Branched from //audio/ears/nnfp/tensorflow/tpu_ops.py

  Args:
    tensor: tensor to concatenate.
    num_replicas: Number of TPU cores.
    name: A name for the op.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  replica_id = xla.replica_id()
  with tf.compat.v1.name_scope(name, 'tpu_cross_replica_concat'):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[replica_id]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # As every value is only present on one replica and 0 in all others,
    # adding them all together will result in the full tensor on all replicas.
    ext_tensor = tf.compat.v1.tpu.cross_replica_sum(ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
def tpu_cross_replica_concat(tensor, num_replicas):
  """Reduce a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    num_replicas: number of TPU device replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  with tf.name_scope('tpu_cross_replica_concat'):
    # This creates a tensor that is like the input tensor but has an added
    # replica dimension as the outermost dimension. On each replica it will
    # contain the local values and zeros for all other values that need to be
    # fetched from other replicas.
    ext_tensor = tf.scatter_nd(
        indices=[[xla.replica_id()]],
        updates=[tensor],
        shape=[num_replicas] + tensor.shape.as_list())

    # As every value is only present on one replica and 0 in all others,
    # adding them all together will result in the full tensor on all replicas.
    replica_context = tf.distribute.get_replica_context()
    ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                            ext_tensor)

    # Flatten the replica dimension.
    # The first dimension size will be: tensor.shape[0] * num_replicas
    # Using [-1] trick to support also scalar input.
    return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
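
# Hypothetical usage sketch for the tf.distribute variant above: the
# all_reduce needs a replica context, so the call must happen inside
# `strategy.run`. The names `strategy` and `per_replica_batch` are
# assumptions for illustration, not part of the original code.
def _cross_replica_concat_step(strategy, per_replica_batch, num_replicas):

  @tf.function
  def step(local_tensor):
    # Inside strategy.run we are in a replica context, so
    # tf.distribute.get_replica_context() returns a non-None context.
    return tpu_cross_replica_concat(local_tensor, num_replicas)

  return strategy.run(step, args=(per_replica_batch,))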
def attractive_repulsive_loss(hidden,
                              hidden_norm=True,
                              temperature=1.0,
                              tpu_context=None,
                              weights=1.0):
  """Compute bottom-up attractive and repulsive loss (contrastive loss).

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  loss = loss_a + loss_b

  return loss, logits_ab, labels
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
  """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  logits_a = tf.concat([logits_ab, logits_aa], 1)
  logits_b = tf.concat([logits_ba, logits_bb], 1)

  if FLAGS.loss_func != 'NT-Xent':
    logits_positive = tf.linalg.diag_part(logits_ab)
    temp_positive = tf.tile(
        tf.expand_dims(logits_positive, -1), [1, logits_a.shape[1]])
    masks_a = tf.cast(
        tf.greater_equal(logits_a, temp_positive - 1e-5), tf.float32)
    masks_b = tf.cast(
        tf.greater_equal(logits_b, temp_positive - 1e-5), tf.float32)
    # Mask out the positive (and any ties) so the max below selects the
    # hardest negative.
    logits_a = logits_a - masks_a * LARGE_NUM
    logits_b = logits_b - masks_b * LARGE_NUM
    logits_negative_a = tf.reduce_max(logits_a, axis=1)
    logits_negative_b = tf.reduce_max(logits_b, axis=1)
    if FLAGS.loss_func == 'NT-Logistic':
      loss_a = tf.reduce_mean(
          tf.log(1 + tf.exp(-logits_positive)) +
          tf.log(1 + tf.exp(logits_negative_a)))
      loss_b = tf.reduce_mean(
          tf.log(1 + tf.exp(-logits_positive)) +
          tf.log(1 + tf.exp(logits_negative_b)))
      tf.losses.add_loss(loss_a + loss_b)
      return loss_a + loss_b, logits_ab, labels
    else:
      # Margin (triplet-style) loss on the hardest negative.
      loss_a = tf.reduce_mean(
          tf.maximum(logits_negative_a - logits_positive + MARGIN, 0))
      loss_b = tf.reduce_mean(
          tf.maximum(logits_negative_b - logits_positive + MARGIN, 0))
      tf.losses.add_loss(loss_a + loss_b)
      return loss_a + loss_b, logits_ab, labels

  loss_a = tf.losses.softmax_cross_entropy(labels, logits_a, weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(labels, logits_b, weights=weights)
  loss = loss_a + loss_b

  return loss, logits_ab, labels
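
# Hedged usage sketch for the function above (assumes TF1 graph mode, the
# module-level constants LARGE_NUM and MARGIN, and FLAGS.loss_func set to
# one of 'NT-Xent', 'NT-Logistic', or the margin branch). For reference,
# the two non-NT-Xent branches implement, per example with positive logit
# s_pos and hardest negative s_neg:
#   NT-Logistic: log(1 + exp(-s_pos)) + log(1 + exp(s_neg))
#   margin:      max(s_neg - s_pos + MARGIN, 0)
def _contrastive_loss_cpu_smoke_test():
  hidden = tf.random.normal([8, 32])  # two views of a batch of 4, stacked
  loss, logits_ab, labels = add_contrastive_loss(
      hidden, hidden_norm=True, temperature=0.5, tpu_context=None)
  return loss, logits_ab, labels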
def contrastive_loss(hidden, num_replicas, normalize_hidden, temperature,
                     model, weight_decay):
  """Computes contrastive loss.

  Args:
    hidden: embedding of video clips after projection head.
    num_replicas: number of distributed replicas.
    normalize_hidden: whether or not to l2 normalize the hidden vector.
    temperature: temperature in the InfoNCE contrastive loss.
    model: keras model for calculating weight decay.
    weight_decay: weight decay parameter.

  Returns:
    A dict with the total loss, the contrastive loss, the regularization
    loss, and the contrastive accuracy/entropy metrics.
  """
  large_num = 1e9
  hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
  if normalize_hidden:
    hidden1 = tf.math.l2_normalize(hidden1, -1)
    hidden2 = tf.math.l2_normalize(hidden2, -1)
  batch_size = tf.shape(hidden1)[0]

  if num_replicas == 1:
    # This is the local version.
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
  else:
    # This is the cross-TPU version.
    hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
    hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ab, logits_aa], 1)))
  loss_b = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
          labels, tf.concat([logits_ba, logits_bb], 1)))
  loss = loss_a + loss_b

  l2_loss = weight_decay * tf.add_n([
      tf.nn.l2_loss(v)
      for v in model.trainable_variables
      if 'kernel' in v.name
  ])
  total_loss = loss + tf.cast(l2_loss, loss.dtype)

  contrast_prob = tf.nn.softmax(logits_ab)
  contrast_entropy = -tf.reduce_mean(
      tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

  contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
  contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

  return {
      'total_loss': total_loss,
      'contrastive_loss': loss,
      'reg_loss': l2_loss,
      'contrast_acc': contrast_acc,
      'contrast_entropy': contrast_entropy,
  }
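
# Hedged usage sketch for the Keras variant above; `projection_model` and
# all shapes are assumptions for illustration, not part of the original.
def _contrastive_loss_demo(projection_model):
  clips = tf.random.normal([16, 64])  # two views of 8 clips, stacked
  hidden = projection_model(clips)    # embeddings after the projection head
  metrics = contrastive_loss(
      hidden,
      num_replicas=1,                 # take the local (non-TPU) path
      normalize_hidden=True,
      temperature=0.1,
      model=projection_model,
      weight_decay=1e-6)
  return metrics['total_loss'], metrics['contrast_acc']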
def tpu_id():
  # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
  replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
  return replica_id
def add_contrastive_loss_multi_aug(hidden,
                                   hidden_norm=True,
                                   temperature=1.0,
                                   tpu_context=None,
                                   weights=1.0):
  """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get the (normalized) hidden views.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hiddens = tf.split(hidden, FLAGS.num_transforms, 0)
  batch_size = tf.shape(hiddens[0])[0]

  if tpu_context is None:
    raise NotImplementedError('GPU not supported')

  hiddens_large = [
      tpu_cross_replica_concat(h, tpu_context) for h in hiddens
  ]
  enlarged_batch_size = tf.shape(hiddens_large[0])[0]
  replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
  labels_idx = tf.range(batch_size) + replica_id * batch_size
  labels = tf.one_hot(labels_idx, enlarged_batch_size * FLAGS.num_transforms)
  masks = tf.one_hot(labels_idx, enlarged_batch_size)

  if FLAGS.adjust_temp:
    # Estimate, per example, how tightly its augmented views cluster
    # ("augmentation cloud density") from pairwise view similarities.
    _aug_cloud_density = []
    for _idx1 in range(FLAGS.num_transforms):
      for _idx2 in range(_idx1 + 1, FLAGS.num_transforms):
        _dot_pd_result = tf.matmul(
            hiddens[_idx1], hiddens[_idx2], transpose_b=True) / temperature
        if FLAGS.exp_to_adjust_temp:
          _dot_pd_result = tf.exp(_dot_pd_result)
        _dot_pd_result = tf.linalg.diag_part(_dot_pd_result)
        _aug_cloud_density.append(tf.expand_dims(_dot_pd_result, axis=1))
    _aug_cloud_density = tf.concat(_aug_cloud_density, axis=1)
    _aug_cloud_density = tf.reduce_mean(_aug_cloud_density, axis=1)
    _large_aug_cloud_density = tpu_cross_replica_concat(
        _aug_cloud_density, tpu_context)
    mean_aug_cloud_density = tf.math.reduce_mean(_large_aug_cloud_density)
    std_aug_cloud_density = tf.math.reduce_std(_large_aug_cloud_density)
    if not FLAGS.rn_avg_for_temp:
      aug_cloud_density = (
          (_aug_cloud_density - mean_aug_cloud_density) /
          std_aug_cloud_density)
    else:
      # Standardize with running statistics instead of batch statistics.
      run_mean = tf.get_variable(
          name="run_mean",
          shape=(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.zeros_initializer())
      run_std = tf.get_variable(
          name="run_std",
          shape=(),
          dtype=tf.float32,
          trainable=False,
          initializer=tf.ones_initializer())
      new_mean = run_mean * 0.99 + mean_aug_cloud_density * 0.01
      new_std = run_std * 0.99 + std_aug_cloud_density * 0.01
      with tf.control_dependencies(
          [run_mean.assign(new_mean), run_std.assign(new_std)]):
        aug_cloud_density = (_aug_cloud_density - new_mean) / new_std

  all_logits = [[] for _ in range(FLAGS.num_transforms)]
  for _idx1 in range(FLAGS.num_transforms):
    for _idx2 in range(FLAGS.num_transforms):
      _logits = tf.matmul(
          hiddens[_idx1], hiddens_large[_idx2],
          transpose_b=True) / temperature
      if FLAGS.adjust_temp:
        if not FLAGS.adjust_temp_params:
          new_temp = temperature + FLAGS.adjust_temp_std * aug_cloud_density
          new_temp = tf.clip_by_value(new_temp, 0.05, 0.15)
        else:
          neg_std, pos_std, neg_bnd, pos_bnd = [
              float(_each_param)
              for _each_param in FLAGS.adjust_temp_params.split(',')
          ]
          aug_pos_flag = tf.cast(aug_cloud_density > 0, tf.float32)
          new_temp = (temperature
                      + neg_std * aug_cloud_density * (1 - aug_pos_flag)
                      + pos_std * aug_cloud_density * aug_pos_flag)
          new_temp = tf.clip_by_value(new_temp, neg_bnd, pos_bnd)
        # [bsz, 1] so the per-example temperature broadcasts over columns.
        new_temp = tf.expand_dims(new_temp, axis=1)
        _logits = tf.matmul(
            hiddens[_idx1], hiddens_large[_idx2], transpose_b=True) / new_temp
      all_logits[_idx1].append(_logits)

  loss = 0
  for _idx1 in range(FLAGS.num_transforms):
    for _idx2 in range(FLAGS.num_transforms):
      if _idx2 == _idx1:
        continue
      _logits = [all_logits[_idx1][_idx2]]
      for _idx3 in range(1, FLAGS.num_transforms):
        new_idx2 = (_idx2 + _idx3) % FLAGS.num_transforms
        _logits.append(all_logits[_idx1][new_idx2] - masks * LARGE_NUM)
      _logits = tf.concat(_logits, 1)
      _loss = tf.losses.softmax_cross_entropy(labels, _logits, weights=weights)
      loss += _loss
  return loss, all_logits[0][1], labels
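
# Hedged toy check of the "augmentation cloud density" idea above: the
# per-example mean of diagonal view-to-view similarities, standardized
# across the batch. NumPy stand-in; shapes are assumptions.
def _aug_cloud_density_demo():
  import numpy as np
  rng = np.random.default_rng(0)
  num_transforms, bsz, dim = 3, 4, 8
  views = [rng.normal(size=(bsz, dim)) for _ in range(num_transforms)]
  pair_sims = []
  for i in range(num_transforms):
    for j in range(i + 1, num_transforms):
      # Diagonal of views[i] @ views[j].T, i.e. same-example similarities.
      pair_sims.append(np.einsum('bd,bd->b', views[i], views[j]))
  density = np.mean(pair_sims, axis=0)                  # [bsz]
  density = (density - density.mean()) / density.std()  # standardized
  return density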
def make_distributed_tensor(strategy, tensors):
  # Give each replica its own slice of `tensors`: stack the per-replica
  # values and index the stack by the replica id inside the replica context.
  stacked = tf.stack(tensors, axis=0)
  fn = tf.function(lambda t: t[xla.replica_id()])
  return strategy.run(fn, args=(stacked,))
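
# Hedged usage sketch: hand each of two replicas its own tensor. The
# `strategy` object is an assumption (e.g. a TPU/mirrored strategy with
# exactly len(tensors) replicas in sync).
def _make_distributed_tensor_demo(strategy):
  t0 = tf.zeros([4])
  t1 = tf.ones([4])
  # Replica 0 receives t0 and replica 1 receives t1.
  return make_distributed_tensor(strategy, [t0, t1])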
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0,
                         loss_type=None,
                         flags=None):
  """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.
    loss_type: which objective to use; one of {nce, chi, js, nwj, dv, wpc}.
    flags: flags object carrying loss hyperparameters (alpha, beta, gamma,
      gradient_penalty_weight).

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  # Check WPC: the gradient penalty only applies to the Wasserstein loss.
  if loss_type.lower() == "wpc":
    assert flags.gradient_penalty_weight != 0.0
  else:
    assert flags.gradient_penalty_weight == 0.0

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature

  # The aa and bb diagonals are not positive samples; the positives are the
  # ab and ba diagonals. Thus we want to mask the aa and bb diagonals out.
  if loss_type.lower() in ("nce", "dv", "wpc"):
    # NCE loss: subtract a big number to get values close to 0 in the softmax.
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = logits_bb - masks * LARGE_NUM
  else:
    # Otherwise just mask out using 0.
    logits_aa = logits_aa * (1 - masks)
    logits_bb = logits_bb * (1 - masks)

  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  #############################################################################
  ### Different losses: nce, chi, js, nwj (plus dv and wpc).
  ### pos_scores: positive samples, i.e. joint distribution terms.
  ### neg_scores: negative samples, i.e. marginal distribution terms.
  #############################################################################
  if loss_type.lower() == "nce":
    loss_a = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
    loss_b = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  elif loss_type.lower() == "chi":
    ## Chi-squared loss in general form.
    alpha = flags.alpha
    beta = flags.beta
    gamma = flags.gamma
    joint_a = labels * tf.concat([logits_ab, logits_aa], 1)
    joint_b = labels * tf.concat([logits_ba, logits_bb], 1)
    # Non-correlated views.
    marg_a = (1. - labels) * tf.concat([logits_ab, logits_aa], 1)
    marg_b = (1. - labels) * tf.concat([logits_ba, logits_bb], 1)
    batch_size = tf.cast(batch_size, tf.float32)
    joint = (
        0.5 * (tf.reduce_sum(joint_a - 0.5 * beta * joint_a**2) / batch_size) +
        0.5 * (tf.reduce_sum(joint_b - 0.5 * beta * joint_b**2) / batch_size))
    # Non-correlated views.
    marg = (
        0.5 * (tf.reduce_sum(alpha * marg_a + 0.5 * gamma * marg_a**2) /
               (2 * batch_size * (batch_size - 1.))) +
        0.5 * (tf.reduce_sum(alpha * marg_b + 0.5 * gamma * marg_b**2) /
               (2 * batch_size * (batch_size - 1.))))
    loss = -1. * (joint - marg)
    tf.losses.add_loss(loss)
    return loss, logits_ab, labels
  elif loss_type.lower() == "js":
    # Jensen-Shannon.
    def js(logits_concat, labels, scope=None):
      lbls = math_ops.cast(labels, logits_concat.dtype)
      # TODO: should a stop_gradient be added here?
      bs = math_ops.cast(batch_size, logits_concat.dtype)
      pos_scores = tf.reduce_sum(
          lbls * (-tf.math.softplus(-logits_concat))) / bs
      neg_scores = tf.reduce_sum(
          (1 - lbls) * tf.math.softplus(logits_concat)) / ((2 * bs - 1) * bs)
      return -(pos_scores - neg_scores)

    loss_a = 0.5 * js(tf.concat([logits_ab, logits_aa], 1), labels)
    loss_b = 0.5 * js(tf.concat([logits_ba, logits_bb], 1), labels)
    tf.losses.add_loss(loss_a)
    tf.losses.add_loss(loss_b)
  elif loss_type.lower() == "nwj":

    def nwj(logits_concat, labels, scope=None):
      lbls = math_ops.cast(labels, logits_concat.dtype)
      # TODO: should a stop_gradient be added here?
      bs = math_ops.cast(batch_size, logits_concat.dtype)
      pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
      neg_scores = tf.reduce_sum(
          (1 - lbls) * tf.math.exp(logits_concat - 1)) / ((2 * bs - 1) * bs)
      return -(pos_scores - neg_scores)

    loss_a = 0.5 * nwj(tf.concat([logits_ab, logits_aa], 1), labels)
    loss_b = 0.5 * nwj(tf.concat([logits_ba, logits_bb], 1), labels)
    tf.losses.add_loss(loss_a)
    tf.losses.add_loss(loss_b)
  elif loss_type.lower() == "dv":
    # Donsker-Varadhan.
    def dv(logits_concat, labels, scope=None):
      lbls = math_ops.cast(labels, logits_concat.dtype)
      # TODO: should a stop_gradient be added here?
      bs = math_ops.cast(batch_size, logits_concat.dtype)
      pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
      neg_scores = tf.math.reduce_logsumexp(
          (1 - lbls) * logits_concat) - tf.math.log((2 * bs - 1) * bs)
      return -(pos_scores - neg_scores)

    loss_a = 0.5 * dv(tf.concat([logits_ab, logits_aa], 1), labels)
    loss_b = 0.5 * dv(tf.concat([logits_ba, logits_bb], 1), labels)
    tf.losses.add_loss(loss_a)
    tf.losses.add_loss(loss_b)
  elif loss_type.lower() == "wpc":
    # Wasserstein Dependency Measure (i.e. Wasserstein Predictive Coding).
    # Adding soon; operation performed in model.py.
    pass
  else:
    raise ValueError(
        "Loss not implemented yet; only support {nce, chi, js, nwj, dv, wpc}")

  loss = loss_a + loss_b
  return loss, logits_ab, labels
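
# For reference, a hedged reading of the estimators the branches above
# implement, with f the critic (the logits), E_p over positive pairs, E_n
# over negative pairs, and sp(z) = log(1 + exp(z)):
#   JS:  maximize E_p[-sp(-f)] - E_n[sp(f)]
#   NWJ: maximize E_p[f] - E_n[exp(f - 1)]
#   DV:  maximize E_p[f] - log E_n[exp(f)]
# Each `loss_*` above is the negation of one such objective averaged over
# the batch, so minimizing the loss tightens the corresponding mutual
# information lower bound.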
def td_attractive_repulsive_loss(reconstruction,
                                 target,
                                 hidden_norm=True,
                                 power=2,
                                 temperature=1.0,
                                 tpu_context=None,
                                 weights=1.0):
  """Compute top-down attractive and repulsive loss based on pixel-wise error.

  Args:
    reconstruction: reconstruction tensor; either (bsz, ...) or, when two
      views are reconstructed, (2 * bsz, ...).
    target: target tensor of shape (bsz, ...).
    hidden_norm: whether or not to use normalization on the hidden vector.
    power: exponent of the pixel-wise error (currently unused).
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) reconstruction and target.
  if hidden_norm:
    reconstruction = tf.math.l2_normalize(reconstruction, -1)
    target = tf.math.l2_normalize(target, -1)
  batch_size = target.get_shape().as_list()[0]
  rec_size = reconstruction.get_shape().as_list()[0]

  if batch_size != rec_size:
    reconstruction1, reconstruction2 = tf.split(reconstruction, 2, 0)
    # batch_size = tf.shape(reconstruction1)[0]

    # Gather reconstruction1/reconstruction2 across replicas and create
    # local labels.
    if tpu_context is not None:
      reconstruction1_large = tpu_cross_replica_concat(reconstruction1,
                                                       tpu_context)
      reconstruction2_large = tpu_cross_replica_concat(reconstruction2,
                                                       tpu_context)
      target_large = tpu_cross_replica_concat(target, tpu_context)
      enlarged_batch_size = tf.shape(reconstruction1_large)[0]
      # TODO(iamtingchen): more elegant way to convert u32 to s32 for
      # replica_id.
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      labels = tf.one_hot(labels_idx, enlarged_batch_size * 3)
      masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
      reconstruction1_large = reconstruction1
      reconstruction2_large = reconstruction2
      target_large = target
      enlarged_batch_size = batch_size
      labels = tf.one_hot(tf.range(batch_size), batch_size * 3)
      masks = tf.one_hot(tf.range(batch_size), batch_size)

    # target = tf.reshape(target, [batch_size, -1])
    reconstruction1 = tf.reshape(reconstruction1, [batch_size, -1])
    reconstruction2 = tf.reshape(reconstruction2, [batch_size, -1])
    target_large = tf.reshape(target_large, [enlarged_batch_size, -1])
    reconstruction1_large = tf.reshape(reconstruction1_large,
                                       [enlarged_batch_size, -1])
    reconstruction2_large = tf.reshape(reconstruction2_large,
                                       [enlarged_batch_size, -1])

    logits_at = tf.matmul(
        reconstruction1, target_large, transpose_b=True) / temperature
    logits_bt = tf.matmul(
        reconstruction2, target_large, transpose_b=True) / temperature
    logits_aa = tf.matmul(
        reconstruction1, reconstruction1_large,
        transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = tf.matmul(
        reconstruction2, reconstruction2_large,
        transpose_b=True) / temperature
    logits_bb = logits_bb - masks * LARGE_NUM
    logits_ab = tf.matmul(
        reconstruction1, reconstruction2_large,
        transpose_b=True) / temperature
    logits_ba = tf.matmul(
        reconstruction2, reconstruction1_large,
        transpose_b=True) / temperature

    loss_a = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_at, logits_aa, logits_ab], 1),
        weights=weights)
    loss_b = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_bt, logits_ba, logits_bb], 1),
        weights=weights)
    loss = loss_a + loss_b
  else:
    # Gather reconstruction across replicas and create local labels.
    if tpu_context is not None:
      reconstruction_large = tpu_cross_replica_concat(reconstruction,
                                                      tpu_context)
      target_large = tpu_cross_replica_concat(target, tpu_context)
      enlarged_batch_size = tf.shape(reconstruction_large)[0]
      # TODO(iamtingchen): more elegant way to convert u32 to s32 for
      # replica_id.
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
      masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
      reconstruction_large = reconstruction
      target_large = target
      enlarged_batch_size = batch_size
      labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
      masks = tf.one_hot(tf.range(batch_size), batch_size)

    reconstruction = tf.reshape(reconstruction, [batch_size, -1])
    target_large = tf.reshape(target_large, [enlarged_batch_size, -1])
    reconstruction_large = tf.reshape(reconstruction_large,
                                      [enlarged_batch_size, -1])

    logits_at = tf.matmul(
        reconstruction, target_large, transpose_b=True) / temperature
    logits_aa = tf.matmul(
        reconstruction, reconstruction_large, transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM

    loss = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_at, logits_aa], 1), weights=weights)

  return loss, logits_at, labels
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
  """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  # Split hidden in half along the batch axis: the first bsz rows hold one
  # augmented view of each image and the last bsz rows hold the other, so
  # hidden1/hidden2 are paired views rather than duplicates.
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]
  # Note: to feed hidden1 and hidden2 through different base models, change
  # how they are computed upstream of this function.

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  loss = loss_a + loss_b

  return loss, logits_ab, labels