示例#1
0
    def _cast_types(self, input_dict):
        """Cast every input to the encoder's dtype.

        Looks up the target dtype under the ``'dtype'`` key of
        ``self.params`` and delegates the actual conversion to
        ``cast_types``.

        Args:
          input_dict (dict): the dictionary that will be handed to the
              :meth:`self._encode() <_encode>` method.

        Returns:
          dict: same as ``input_dict``, but with all Tensors cast to the
          encoder dtype.
        """
        encoder_dtype = self.params['dtype']
        return cast_types(input_dict, encoder_dtype)
示例#2
0
文件: loss.py 项目: fotwo/OpenSeq2Seq
  def _cast_types(self, input_dict):
    """Cast every input to the loss's dtype.

    The target dtype is read from ``self.params['dtype']``; the
    conversion itself is delegated to ``cast_types``.

    Args:
      input_dict (dict): the dictionary that will be handed to the
          :meth:`self._compute_loss() <_compute_loss>` method.

    Returns:
      dict: same as ``input_dict``, but with all Tensors cast to the loss
      dtype.
    """
    loss_dtype = self.params['dtype']
    return cast_types(input_dict, loss_dtype)