示例#1
0
def CountWeights(mask_id=None, has_weights=False):
    """Return a layer that totals the weights assigned to all elements.

    Args:
      mask_id: Forwarded unchanged to `WeightMask`.
      has_weights: When true, a `cb.Multiply()` layer is inserted between the
        weight mask and the final sum (to multiply with the provided mask).

    Returns:
      A `cb.Serial` combinator that drops the inputs, builds the weight mask,
      optionally multiplies, and finally applies `core.Sum(axis=None)`.
    """
    pipeline = [
        cb.Drop(),  # Drop inputs.
        WeightMask(mask_id=mask_id),  # pylint: disable=no-value-for-parameter
    ]
    if has_weights:
        pipeline.append(cb.Multiply())  # Multiply with provided mask.
    pipeline.append(core.Sum(axis=None))  # Sum all weights.
    return cb.Serial(*pipeline)
示例#2
0
def SumOfWeights():
    """Build a layer that computes the sum of weights.

    Two `cb.Drop()` layers discard the inputs and then the targets; the
    remaining step is `core.Sum(axis=None)`, which sums the weights.
    """
    drop_inputs = cb.Drop()
    drop_targets = cb.Drop()
    sum_weights = core.Sum(axis=None)
    return cb.Serial(drop_inputs, drop_targets, sum_weights)
示例#3
0
文件: metrics.py 项目: zhaoqiuye/trax
def SumOfWeights():
    """Return a layer computing the sum of weights of all non-masked elements.

    The layer chain is: drop inputs, drop targets, then sum the weights with
    `core.Sum(axis=None)`.
    """
    # Order matters: inputs, targets, then the weight reduction.
    stages = (cb.Drop(), cb.Drop(), core.Sum(axis=None))
    return cb.Serial(*stages)
示例#4
0
文件: metrics.py 项目: zsunpku/trax
def SumOfWeights(id_to_mask=None, has_weights=False):
  """Build a layer summing the weights of all non-masked elements.

  Args:
    id_to_mask: Forwarded unchanged to `_ElementMask`.
    has_weights: If true, a `cb.Multiply()` layer follows the element mask;
      otherwise an empty list is passed to `cb.Serial` in its place.

  Returns:
    A `cb.Serial` combinator: drop inputs, apply the element mask, optionally
    multiply, then `core.Sum(axis=None)`.
  """
  # The optional multiply layer is constructed first, as before.
  if has_weights:
    weighting = cb.Multiply()
  else:
    weighting = []
  drop_inputs = cb.Drop()  # Drop inputs.
  element_mask = _ElementMask(id_to_mask=id_to_mask)
  sum_all = core.Sum(axis=None)  # Sum all.
  return cb.Serial(drop_inputs, element_mask, weighting, sum_all)