Example #1
def tpu_cross_replica_concat(tensor, tpu_context=None):
    """Reduce a concatenation of the `tensor` across TPU cores.

    Args:
      tensor: tensor to concatenate.
      tpu_context: A `TPUContext`. If not set, CPU execution is assumed.

    Returns:
      Tensor of the same rank as `tensor` with first dimension `num_replicas`
      times larger.
    """
    if tpu_context is None or tpu_context.num_replicas <= 1:
        return tensor

    num_replicas = tpu_context.num_replicas

    with tf.name_scope("tpu_cross_replica_concat"):
        # This creates a tensor that is like the input tensor but has an added
        # replica dimension as the outermost dimension. On each replica it will
        # contain the local values and zeros for all other values that need to be
        # fetched from other replicas.
        ext_tensor = tf.scatter_nd(
            indices=[[xla.replica_id()]],
            updates=[tensor],
            shape=[num_replicas] + tensor.shape.as_list(),
        )

        # As every value is only present on one replica and 0 in all others, adding
        # them all together will result in the full tensor on all replicas.
        ext_tensor = tf.tpu.cross_replica_sum(ext_tensor)

        # Flatten the replica dimension.
        # The first dimension size will be: tensor.shape[0] * num_replicas
        # Use the [-1] trick to also support scalar input.
        return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
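
A quick way to see why the scatter-then-sum trick above yields a concatenation is to simulate it in NumPy. The sketch below (not part of the original example) assumes a hypothetical 4-replica setup: each replica writes its local slice into a zero-initialized buffer, summing the buffers stands in for `cross_replica_sum`, and flattening the replica dimension reproduces the concatenation of all local tensors.

import numpy as np

num_replicas = 4
# Each simulated replica holds a local (2, 3) tensor filled with its replica id.
local = [np.full((2, 3), r, dtype=np.float32) for r in range(num_replicas)]

buffers = []
for replica_id, tensor in enumerate(local):
    # Mirrors tf.scatter_nd: local values at index `replica_id`, zeros elsewhere.
    ext = np.zeros((num_replicas,) + tensor.shape, dtype=tensor.dtype)
    ext[replica_id] = tensor
    buffers.append(ext)

# Mirrors tf.tpu.cross_replica_sum: every replica would end up with this sum.
summed = np.sum(buffers, axis=0)

# Flatten the replica dimension, as in the tf.reshape above.
result = summed.reshape((-1,) + summed.shape[2:])
assert np.array_equal(result, np.concatenate(local, axis=0))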
Example #2
def cross_replica_concat(tensor, num_replicas, name=None):
    """Reduce a concatenation of the `tensor` across tpu cores.

  Branched from //audio/ears/nnfp/tensorflow/tpu_ops.py

  Args:
    tensor: tensor to concatenate.
    num_replicas: Number of TPU cores.
    name: A name for the op.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
    replica_id = xla.replica_id()

    with tf.compat.v1.name_scope(name, 'tpu_cross_replica_concat'):
        # This creates a tensor that is like the input tensor but has an added
        # replica dimension as the outermost dimension. On each replica it will
        # contain the local values and zeros for all other values that need to be
        # fetched from other replicas.
        ext_tensor = tf.scatter_nd(indices=[[replica_id]],
                                   updates=[tensor],
                                   shape=[num_replicas] +
                                   tensor.shape.as_list())

        # As every value is only present on one replica and 0 in all others, adding
        # them all together will result in the full tensor on all replicas.
        ext_tensor = tf.compat.v1.tpu.cross_replica_sum(ext_tensor)

        # Flatten the replica dimension.
        # The first dimension size will be: tensor.shape[0] * num_replicas
        # Use the [-1] trick to also support scalar input.
        return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
Example #3
def tpu_cross_replica_concat(tensor, num_replicas):
    """Reduce a concatenation of the `tensor` across TPU cores.

  Args:
    tensor: tensor to concatenate.
    num_replicas: number of TPU device replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
    with tf.name_scope('tpu_cross_replica_concat'):
        # This creates a tensor that is like the input tensor but has an added
        # replica dimension as the outermost dimension. On each replica it will
        # contain the local values and zeros for all other values that need to be
        # fetched from other replicas.
        ext_tensor = tf.scatter_nd(indices=[[xla.replica_id()]],
                                   updates=[tensor],
                                   shape=[num_replicas] +
                                   tensor.shape.as_list())

        # As every value is only present on one replica and 0 in all others, adding
        # them all together will result in the full tensor on all replicas.
        replica_context = tf.distribute.get_replica_context()
        ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
                                                ext_tensor)

        # Flatten the replica dimension.
        # The first dimension size will be: tensor.shape[0] * num_replicas
        # Use the [-1] trick to also support scalar input.
        return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
Example #4
def attractive_repulsive_loss(hidden,
                              hidden_norm=True,
                              temperature=1.0,
                              tpu_context=None,
                              weights=1.0):
  """Compute bottom-up attractive and repulsive (contrastive) loss.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.
  if hidden_norm:
    hidden = tf.math.l2_normalize(hidden, -1)
  hidden1, hidden2 = tf.split(hidden, 2, 0)
  batch_size = tf.shape(hidden1)[0]

  # Gather hidden1/hidden2 across replicas and create local labels.
  if tpu_context is not None:
    hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
    hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
    enlarged_batch_size = tf.shape(hidden1_large)[0]
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)
  else:
    hidden1_large = hidden1
    hidden2_large = hidden2
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)

  logits_aa = tf.matmul(hidden1, hidden1_large, transpose_b=True) / temperature
  logits_aa = logits_aa - masks * LARGE_NUM
  logits_bb = tf.matmul(hidden2, hidden2_large, transpose_b=True) / temperature
  logits_bb = logits_bb - masks * LARGE_NUM
  logits_ab = tf.matmul(hidden1, hidden2_large, transpose_b=True) / temperature
  logits_ba = tf.matmul(hidden2, hidden1_large, transpose_b=True) / temperature

  loss_a = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
  loss_b = tf.losses.softmax_cross_entropy(
      labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
  loss = loss_a + loss_b

  return loss, logits_ab, labels
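
As a reading aid (not part of the original example), the single-replica branch of the label and mask construction can be made concrete with an assumed batch of 2 pairs: each row's positive is the matching column of `logits_ab`, i.e. one of the first `batch_size` columns of `tf.concat([logits_ab, logits_aa], 1)`, while `masks` marks each row's self-similarity in `logits_aa`/`logits_bb` so it can be pushed down by `LARGE_NUM`.

import numpy as np

batch_size = 2  # assumed toy batch size

# tf.one_hot(tf.range(batch_size), batch_size * 2): the positive for row i is
# column i, i.e. the diagonal of logits_ab inside concat([logits_ab, logits_aa], 1).
labels = np.eye(batch_size, 2 * batch_size)
# [[1. 0. 0. 0.]
#  [0. 1. 0. 0.]]

# tf.one_hot(tf.range(batch_size), batch_size): each row's own entry in
# logits_aa / logits_bb, which the code above shifts by -LARGE_NUM.
masks = np.eye(batch_size, batch_size)
# [[1. 0.]
#  [0. 1.]]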
Example #5
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
    """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
    # Get (normalized) hidden1 and hidden2.
    if hidden_norm:
        hidden = tf.math.l2_normalize(hidden, -1)
    hidden1, hidden2 = tf.split(hidden, 2, 0)
    batch_size = tf.shape(hidden1)[0]

    # Gather hidden1/hidden2 across replicas and create local labels.
    if tpu_context is not None:
        hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
        hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
        enlarged_batch_size = tf.shape(hidden1_large)[0]
        # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
        replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
        labels_idx = tf.range(batch_size) + replica_id * batch_size
        labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
        masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
        masks = tf.one_hot(tf.range(batch_size), batch_size)

    logits_aa = tf.matmul(hidden1, hidden1_large,
                          transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = tf.matmul(hidden2, hidden2_large,
                          transpose_b=True) / temperature
    logits_bb = logits_bb - masks * LARGE_NUM
    logits_ab = tf.matmul(hidden1, hidden2_large,
                          transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1_large,
                          transpose_b=True) / temperature

    logits_a = tf.concat([logits_ab, logits_aa], 1)
    logits_b = tf.concat([logits_ba, logits_bb], 1)

    if FLAGS.loss_func != 'NT-Xent':
        logits_positive = tf.diag_part(logits_ab)
        temp_positive = tf.tile(tf.expand_dims(logits_positive, -1),
                                [1, logits_a.shape[1]])
        masks_a = tf.cast(tf.greater_equal(logits_a, temp_positive - 1e-5),
                          tf.float32)
        masks_b = tf.cast(tf.greater_equal(logits_b, temp_positive - 1e-5),
                          tf.float32)
        logits_a = logits_a - masks_a * LARGE_NUM
        logits_b = logits_b - masks_b * LARGE_NUM
        logits_negative_a = tf.reduce_max(logits_a, axis=1)
        logits_negative_b = tf.reduce_max(logits_b, axis=1)
        #print(logits_negative_a, logits_negative_b)
        if FLAGS.loss_func == 'NT-Logistic':
            loss_a = tf.reduce_mean(
                tf.log(1 + tf.exp(-logits_positive)) +
                tf.log(1 + tf.exp(logits_negative_a)))
            loss_b = tf.reduce_mean(
                tf.log(1 + tf.exp(-logits_positive)) +
                tf.log(1 + tf.exp(logits_negative_b)))
            tf.losses.add_loss(loss_a + loss_b)
            #print(loss_a, loss_b)
            return loss_a + loss_b, logits_ab, labels
        else:
            loss_a = tf.reduce_mean(
                tf.maximum(logits_negative_a - logits_positive + MARGIN, 0))
            loss_b = tf.reduce_mean(
                tf.maximum(logits_negative_b - logits_positive + MARGIN, 0))
            tf.losses.add_loss(loss_a + loss_b)
            return loss_a + loss_b, logits_ab, labels

    loss_a = tf.losses.softmax_cross_entropy(labels, logits_a, weights=weights)
    loss_b = tf.losses.softmax_cross_entropy(labels, logits_b, weights=weights)
    #print(loss_a, loss_b)
    loss = loss_a + loss_b
    return loss, logits_ab, labels
Example #6
def contrastive_loss(hidden, num_replicas, normalize_hidden, temperature,
                     model, weight_decay):
    """Computes contrastive loss.

  Args:
    hidden: embedding of video clips after projection head.
    num_replicas: number of distributed replicas.
    normalize_hidden: whether or not to l2 normalize the hidden vector.
    temperature: temperature in the InfoNCE contrastive loss.
    model: keras model for calculating weight decay.
    weight_decay: weight decay parameter.

  Returns:
    A dict with the total loss, the contrastive loss, the regularization
    (weight decay) loss, the contrastive accuracy and the contrastive entropy.
  """
    large_num = 1e9

    hidden1, hidden2 = tf.split(hidden, num_or_size_splits=2, axis=0)
    if normalize_hidden:
        hidden1 = tf.math.l2_normalize(hidden1, -1)
        hidden2 = tf.math.l2_normalize(hidden2, -1)
    batch_size = tf.shape(hidden1)[0]

    if num_replicas == 1:
        # This is the local version
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
        masks = tf.one_hot(tf.range(batch_size), batch_size)

    else:
        # This is the cross-tpu version.
        hidden1_large = tpu_cross_replica_concat(hidden1, num_replicas)
        hidden2_large = tpu_cross_replica_concat(hidden2, num_replicas)
        enlarged_batch_size = tf.shape(hidden1_large)[0]
        replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
        labels_idx = tf.range(batch_size) + replica_id * batch_size
        labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
        masks = tf.one_hot(labels_idx, enlarged_batch_size)

    logits_aa = tf.matmul(hidden1, hidden1_large,
                          transpose_b=True) / temperature
    logits_aa = logits_aa - tf.cast(masks, logits_aa.dtype) * large_num
    logits_bb = tf.matmul(hidden2, hidden2_large,
                          transpose_b=True) / temperature
    logits_bb = logits_bb - tf.cast(masks, logits_bb.dtype) * large_num
    logits_ab = tf.matmul(hidden1, hidden2_large,
                          transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1_large,
                          transpose_b=True) / temperature

    loss_a = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels, tf.concat([logits_ab, logits_aa], 1)))
    loss_b = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            labels, tf.concat([logits_ba, logits_bb], 1)))
    loss = loss_a + loss_b

    l2_loss = weight_decay * tf.add_n([
        tf.nn.l2_loss(v)
        for v in model.trainable_variables if 'kernel' in v.name
    ])

    total_loss = loss + tf.cast(l2_loss, loss.dtype)

    contrast_prob = tf.nn.softmax(logits_ab)
    contrast_entropy = -tf.reduce_mean(
        tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

    contrast_acc = tf.equal(tf.argmax(labels, 1), tf.argmax(logits_ab, axis=1))
    contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))

    return {
        'total_loss': total_loss,
        'contrastive_loss': loss,
        'reg_loss': l2_loss,
        'contrast_acc': contrast_acc,
        'contrast_entropy': contrast_entropy,
    }
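
A hypothetical usage sketch (the projection model, shapes and hyper-parameters below are illustrative assumptions, not taken from the source): with `num_replicas=1` the function follows the local code path, so it can be exercised on CPU or GPU with a toy Keras model in TF2.

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(32, name='proj')])
# Two augmented views stacked along the batch axis: 2 * bsz = 8 clips, dim 64.
hidden = model(tf.random.normal([8, 64]))

metrics = contrastive_loss(
    hidden,
    num_replicas=1,        # single replica -> local (non-TPU) branch
    normalize_hidden=True,
    temperature=0.1,
    model=model,
    weight_decay=1e-4)
print(metrics['total_loss'], metrics['contrast_acc'])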
Example #7
def tpu_id():
    # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    return replica_id
Example #8
def add_contrastive_loss_multi_aug(hidden,
                                   hidden_norm=True,
                                   temperature=1.0,
                                   tpu_context=None,
                                   weights=1.0):
    """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
    # Get (normalized) hidden1 and hidden2.
    if hidden_norm:
        hidden = tf.math.l2_normalize(hidden, -1)
    hiddens = tf.split(hidden, FLAGS.num_transforms, 0)
    batch_size = tf.shape(hiddens[0])[0]

    if tpu_context is None:
        raise NotImplementedError('GPU not supported')

    hiddens_large = [tpu_cross_replica_concat(hidden, tpu_context)\
                     for hidden in hiddens]
    enlarged_batch_size = tf.shape(hiddens_large[0])[0]
    replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
    labels_idx = tf.range(batch_size) + replica_id * batch_size
    labels = tf.one_hot(labels_idx, enlarged_batch_size * FLAGS.num_transforms)
    masks = tf.one_hot(labels_idx, enlarged_batch_size)

    if FLAGS.adjust_temp:
        _aug_cloud_density = []
        for _idx1 in range(FLAGS.num_transforms):
            for _idx2 in range(_idx1 + 1, FLAGS.num_transforms):
                _dot_pd_result = tf.matmul(hiddens[_idx1],
                                           hiddens[_idx2],
                                           transpose_b=True) / temperature
                if FLAGS.exp_to_adjust_temp:
                    _dot_pd_result = tf.exp(_dot_pd_result)
                _dot_pd_result = tf.linalg.diag_part(_dot_pd_result)
                _aug_cloud_density.append(
                    tf.expand_dims(_dot_pd_result, axis=1))
        _aug_cloud_density = tf.concat(_aug_cloud_density, axis=1)
        _aug_cloud_density = tf.reduce_mean(_aug_cloud_density, axis=1)
        _large_aug_cloud_density = tpu_cross_replica_concat(
            _aug_cloud_density, tpu_context)
        mean_aug_cloud_density = tf.math.reduce_mean(_large_aug_cloud_density)
        std_aug_cloud_density = tf.math.reduce_std(_large_aug_cloud_density)
        if not FLAGS.rn_avg_for_temp:
            aug_cloud_density = \
                (_aug_cloud_density - mean_aug_cloud_density) / std_aug_cloud_density
        else:
            run_mean = tf.get_variable(name="run_mean",
                                       shape=(),
                                       dtype=tf.float32,
                                       trainable=False,
                                       initializer=tf.zeros_initializer())
            run_std = tf.get_variable(name="run_std",
                                      shape=(),
                                      dtype=tf.float32,
                                      trainable=False,
                                      initializer=tf.ones_initializer())
            new_mean = run_mean * 0.99 + mean_aug_cloud_density * 0.01
            new_std = run_std * 0.99 + std_aug_cloud_density * 0.01
            with tf.control_dependencies(
                [run_mean.assign(new_mean),
                 run_std.assign(new_std)]):
                aug_cloud_density = \
                    (_aug_cloud_density - new_mean) / new_std

    all_logits = [[] for _ in range(FLAGS.num_transforms)]
    for _idx1 in range(FLAGS.num_transforms):
        for _idx2 in range(FLAGS.num_transforms):
            _logits = tf.matmul(hiddens[_idx1],
                                hiddens_large[_idx2],
                                transpose_b=True) / temperature
            if FLAGS.adjust_temp:
                if not FLAGS.adjust_temp_params:
                    new_temp = temperature + FLAGS.adjust_temp_std * aug_cloud_density
                    new_temp = tf.clip_by_value(new_temp, 0.05, 0.15)
                else:
                    neg_std, pos_std, neg_bnd, pos_bnd \
                        = [float(_each_param) \
                           for _each_param in FLAGS.adjust_temp_params.split(',')]
                    aug_pos_flag = tf.cast(aug_cloud_density > 0, tf.float32)
                    new_temp = temperature \
                        + neg_std * aug_cloud_density * (1-aug_pos_flag) \
                        + pos_std * aug_cloud_density * aug_pos_flag
                    new_temp = tf.clip_by_value(new_temp, neg_bnd, pos_bnd)
                new_temp = tf.expand_dims(new_temp, axis=1)
                _logits = tf.matmul(hiddens[_idx1],
                                    hiddens_large[_idx2],
                                    transpose_b=True) / new_temp
            all_logits[_idx1].append(_logits)

    loss = 0
    for _idx1 in range(FLAGS.num_transforms):
        for _idx2 in range(FLAGS.num_transforms):
            if _idx2 == _idx1:
                continue
            _logits = [all_logits[_idx1][_idx2]]
            for _idx3 in range(1, FLAGS.num_transforms):
                new_idx2 = (_idx2 + _idx3) % FLAGS.num_transforms
                _logits.append(all_logits[_idx1][new_idx2] - masks * LARGE_NUM)
            _logits = tf.concat(_logits, 1)
            _loss = tf.losses.softmax_cross_entropy(labels,
                                                    _logits,
                                                    weights=weights)
            loss += _loss
    return loss, all_logits[0][1], labels
Example #9
def make_distributed_tensor(strategy, tensors):
    stacked = tf.stack(tensors, axis=0)
    fn = tf.function(lambda t: t[xla.replica_id()])
    return strategy.run(fn, args=(stacked, ))
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0,
                         loss_type=None,
                         flags=None):
    """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.
    loss_type: which loss to use; one of {nce, chi, js, nwj, dv, wpc}.
    flags: flags object providing alpha, beta, gamma and
      gradient_penalty_weight.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
    # Get (normalized) hidden1 and hidden2.
    if hidden_norm:
        hidden = tf.math.l2_normalize(hidden, -1)
    hidden1, hidden2 = tf.split(hidden, 2, 0)
    batch_size = tf.shape(hidden1)[0]

    # Gather hidden1/hidden2 across replicas and create local labels.
    if tpu_context is not None:
        hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
        hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
        enlarged_batch_size = tf.shape(hidden1_large)[0]
        replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
        labels_idx = tf.range(batch_size) + replica_id * batch_size
        labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
        masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
        masks = tf.one_hot(tf.range(batch_size), batch_size)
    # check WPC
    if loss_type.lower() == "wpc":
        assert flags.gradient_penalty_weight != 0.0
    else:
        assert flags.gradient_penalty_weight == 0.0
    logits_aa = tf.matmul(hidden1, hidden1_large,
                          transpose_b=True) / temperature
    logits_bb = tf.matmul(hidden2, hidden2_large,
                          transpose_b=True) / temperature
    # The aa and bb diagonals are not positive samples; the positives are the
    # ab and ba diagonals, so we mask the aa and bb diagonals out.
    if loss_type.lower() == "nce" or loss_type.lower(
    ) == "dv" or loss_type.lower() == "wpc":
        # NCE loss: subtract a large number so these entries become close to 0 after the softmax
        print(loss_type)
        logits_aa = logits_aa - masks * LARGE_NUM
        logits_bb = logits_bb - masks * LARGE_NUM
    else:  # otherwise just mask out using 0
        logits_aa = logits_aa * (1 - masks)
        logits_bb = logits_bb * (1 - masks)
    logits_ab = tf.matmul(hidden1, hidden2_large,
                          transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1_large,
                          transpose_b=True) / temperature
    #############################################################################
    ### Different losses: nce, chi, js and nwj
    ### Pos_scores: positive samples, i.e. joint distribution terms
    ### neg_scores: negative samples, i.e. marginal distribution terms
    #############################################################################
    if loss_type.lower() == "nce":
        loss_a = tf.losses.softmax_cross_entropy(
            labels, tf.concat([logits_ab, logits_aa], 1), weights=weights)
        loss_b = tf.losses.softmax_cross_entropy(
            labels, tf.concat([logits_ba, logits_bb], 1), weights=weights)
    elif loss_type.lower() == "chi":
        ## Chi squared loss in general form
        alpha = flags.alpha
        beta = flags.beta
        gamma = flags.gamma

        joint_a = labels * tf.concat([logits_ab, logits_aa], 1)
        joint_b = labels * tf.concat([logits_ba, logits_bb], 1)
        # non-correlated views
        marg_a = (1. - labels) * tf.concat([logits_ab, logits_aa], 1)
        marg_b = (1. - labels) * tf.concat([logits_ba, logits_bb], 1)
        batch_size = tf.cast(batch_size, tf.float32)
        joint = 0.5*(tf.reduce_sum(joint_a - 0.5 * beta * joint_a**2) /  batch_size)\
                + 0.5*(tf.reduce_sum(joint_b - 0.5 * beta * joint_b**2) /  batch_size)
        # non-correlated views
        marg = 0.5*(tf.reduce_sum(alpha * marg_a + 0.5 * gamma * marg_a**2) /  (2*batch_size*(batch_size-1.)))\
                + 0.5*(tf.reduce_sum(alpha * marg_b + 0.5 * gamma * marg_b**2) /  (2*batch_size*(batch_size-1.)))
        loss = -1. * (joint - marg)
        tf.losses.add_loss(loss)
        return loss, logits_ab, labels

    elif loss_type.lower() == "js":
        # Jensen Shannon
        def js(logits_concat, labels, scope=None):
            lbls = math_ops.cast(labels, logits_concat.dtype)
            """SHOULD I ADD STOP GRADIENT?"""
            bs = math_ops.cast(batch_size, logits_concat.dtype)
            pos_scores = tf.reduce_sum(
                lbls * (-tf.math.softplus(-logits_concat))) / bs
            neg_scores = tf.reduce_sum(
                (1 - lbls) * tf.math.softplus(logits_concat)) / (
                    (2 * bs - 1) * bs)
            return -(pos_scores - neg_scores)

        loss_a = 0.5 * js(tf.concat([logits_ab, logits_aa], 1), labels)
        loss_b = 0.5 * js(tf.concat([logits_ba, logits_bb], 1), labels)
        tf.losses.add_loss(loss_a)
        tf.losses.add_loss(loss_b)

    elif loss_type.lower() == "nwj":

        def nwj(logits_concat, labels, scope=None):
            lbls = math_ops.cast(labels, logits_concat.dtype)
            """SHOULD I ADD STOP GRADIENT?"""
            bs = math_ops.cast(batch_size, logits_concat.dtype)
            pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
            neg_scores = tf.reduce_sum(
                (1 - lbls) * tf.math.exp(logits_concat - 1)) / (
                    (2 * bs - 1) * bs)
            return -(pos_scores - neg_scores)

        loss_a = 0.5 * nwj(tf.concat([logits_ab, logits_aa], 1), labels)
        loss_b = 0.5 * nwj(tf.concat([logits_ba, logits_bb], 1), labels)
        tf.losses.add_loss(loss_a)
        tf.losses.add_loss(loss_b)
    elif loss_type.lower() == "dv":
        # Donsker and Varadhan
        def dv(logits_concat, labels, scope=None):
            lbls = math_ops.cast(labels, logits_concat.dtype)
            """SHOULD I ADD STOP GRADIENT?"""
            bs = math_ops.cast(batch_size, logits_concat.dtype)
            pos_scores = tf.reduce_sum(lbls * logits_concat) / bs
            neg_scores = tf.math.reduce_logsumexp(
                (1 - lbls) * logits_concat) - tf.math.log((2 * bs - 1) * bs)
            return -(pos_scores - neg_scores)

        loss_a = 0.5 * dv(tf.concat([logits_ab, logits_aa], 1), labels)
        loss_b = 0.5 * dv(tf.concat([logits_ba, logits_bb], 1), labels)
        tf.losses.add_loss(loss_a)
        tf.losses.add_loss(loss_b)
    elif loss_type.lower() == "wpc":
        # Wasserstein Dependency Measure (i.e. Wasserstein Predictive Coding)
        # Adding soon
        pass  # operation performed in model.py

    else:
        raise ValueError(
            "Loss not implemented yet; only {nce, chi, js, nwj, dv, wpc} are supported")

    loss = loss_a + loss_b
    return loss, logits_ab, labels
Example #11
def td_attractive_repulsive_loss(reconstruction,
                                 target,
                                 hidden_norm=True,
                                 power=2,
                                 temperature=1.0,
                                 tpu_context=None,
                                 weights=1.0):
  """Compute top down attractive and repulsive loss base on pixel-wise error.
  Args:
    hidden: hidden vector (`Tensor`) of shape (bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.
  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
  # Get (normalized) hidden1 and hidden2.

  if hidden_norm:
    reconstruction = tf.math.l2_normalize(reconstruction, -1)
    target = tf.math.l2_normalize(target, -1)

  batch_size = target.get_shape().as_list()[0]
  rec_size = reconstruction.get_shape().as_list()[0]
  
  if batch_size != rec_size:

    reconstruction1, reconstruction2 = tf.split(reconstruction, 2, 0)
    # batch_size = tf.shape(reconstruction1)[0]

    # Gather hidden1/hidden2 across replicas and create local labels.
    if tpu_context is not None:
      reconstruction1_large = tpu_cross_replica_concat(reconstruction1, tpu_context)
      reconstruction2_large = tpu_cross_replica_concat(reconstruction2, tpu_context)
      target_large = tpu_cross_replica_concat(target, tpu_context)
      
      enlarged_batch_size = tf.shape(reconstruction1_large)[0]
      # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      labels = tf.one_hot(labels_idx, enlarged_batch_size * 3)
      masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
      reconstruction1_large = reconstruction1
      reconstruction2_large = reconstruction2
      labels = tf.one_hot(tf.range(batch_size), batch_size * 3)
      masks = tf.one_hot(tf.range(batch_size), batch_size)

    # target = tf.reshape(target, [batch_size, -1])
    reconstruction1 = tf.reshape(reconstruction1, [batch_size, -1])
    reconstruction2 = tf.reshape(reconstruction2, [batch_size, -1])
    
    target_large = tf.reshape(target_large, [enlarged_batch_size, -1])
    reconstruction1_large = tf.reshape(reconstruction1_large, [enlarged_batch_size, -1])
    reconstruction2_large = tf.reshape(reconstruction2_large, [enlarged_batch_size, -1])

    logits_at = tf.matmul(reconstruction1, target_large, transpose_b=True) / temperature
    logits_bt = tf.matmul(reconstruction2, target_large, transpose_b=True) / temperature

    logits_aa = tf.matmul(reconstruction1, reconstruction1_large, transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM

    logits_bb = tf.matmul(reconstruction2, reconstruction2_large, transpose_b=True) / temperature
    logits_bb = logits_bb - masks * LARGE_NUM

    logits_ab = tf.matmul(reconstruction1, reconstruction2_large, transpose_b=True) / temperature
    logits_ba = tf.matmul(reconstruction2, reconstruction1_large, transpose_b=True) / temperature
    
    loss_a = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_at, logits_aa, logits_ab], 1), weights=weights)
    loss_b = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_bt, logits_ba, logits_bb], 1), weights=weights)
    loss = loss_a + loss_b
  
  else:
    # Gather hidden1/hidden2 across replicas and create local labels.
    if tpu_context is not None:
      reconstruction_large = tpu_cross_replica_concat(reconstruction, tpu_context)
      target_large = tpu_cross_replica_concat(target, tpu_context)
      
      enlarged_batch_size = tf.shape(reconstruction_large)[0]
      # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
      replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
      masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
      reconstruction_large = reconstruction
      labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
      masks = tf.one_hot(tf.range(batch_size), batch_size)

    reconstruction = tf.reshape(reconstruction, [batch_size, -1])
    
    target_large = tf.reshape(target_large, [enlarged_batch_size, -1])
    reconstruction_large = tf.reshape(reconstruction_large, [enlarged_batch_size, -1])
    
    logits_at = tf.matmul(reconstruction, target_large, transpose_b=True) / temperature
    logits_aa = tf.matmul(reconstruction, reconstruction_large, transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    

    loss = tf.losses.softmax_cross_entropy(
        labels, tf.concat([logits_at, logits_aa], 1), weights=weights) #tf.concat([logits_at, logits_aa], 1)

  return loss, logits_at, labels
Example #12
def add_contrastive_loss(hidden,
                         hidden_norm=True,
                         temperature=1.0,
                         tpu_context=None,
                         weights=1.0):
    """Compute loss for model.

  Args:
    hidden: hidden vector (`Tensor`) of shape (2 * bsz, dim).
    hidden_norm: whether or not to use normalization on the hidden vector.
    temperature: a `floating` number for temperature scaling.
    tpu_context: context information for tpu.
    weights: a weighting number or vector.

  Returns:
    A loss scalar.
    The logits for contrastive prediction task.
    The labels for contrastive prediction task.
  """
    # Get (normalized) hidden1 and hidden2.
    if hidden_norm:
        hidden = tf.math.l2_normalize(hidden, -1)
    hidden1, hidden2 = tf.split(
        hidden, 2, 0
    )  # splits hidden in half along 0 axis (batch size axis?), but should be duplicating hidden??
    batch_size = tf.shape(
        hidden1
    )[0]  # maybe one batch from dataloader = bs/2 images + bs/2 transformed images ??
    # we need to change how hidden1 and hidden2 are calculated so that they are fed through different base models

    # Gather hidden1/hidden2 across replicas and create local labels.
    if tpu_context is not None:
        hidden1_large = tpu_cross_replica_concat(hidden1, tpu_context)
        hidden2_large = tpu_cross_replica_concat(hidden2, tpu_context)
        enlarged_batch_size = tf.shape(hidden1_large)[0]
        # TODO(iamtingchen): more elegant way to convert u32 to s32 for replica_id.
        replica_id = tf.cast(tf.cast(xla.replica_id(), tf.uint32), tf.int32)
        labels_idx = tf.range(batch_size) + replica_id * batch_size
        labels = tf.one_hot(labels_idx, enlarged_batch_size * 2)
        masks = tf.one_hot(labels_idx, enlarged_batch_size)
    else:
        hidden1_large = hidden1
        hidden2_large = hidden2
        labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
        masks = tf.one_hot(tf.range(batch_size), batch_size)

    logits_aa = tf.matmul(hidden1, hidden1_large,
                          transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = tf.matmul(hidden2, hidden2_large,
                          transpose_b=True) / temperature
    logits_bb = logits_bb - masks * LARGE_NUM
    logits_ab = tf.matmul(hidden1, hidden2_large,
                          transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1_large,
                          transpose_b=True) / temperature

    loss_a = tf.losses.softmax_cross_entropy(labels,
                                             tf.concat([logits_ab, logits_aa],
                                                       1),
                                             weights=weights)
    loss_b = tf.losses.softmax_cross_entropy(labels,
                                             tf.concat([logits_ba, logits_bb],
                                                       1),
                                             weights=weights)
    loss = loss_a + loss_b

    return loss, logits_ab, labels