def __init__(self, config):
    """Set up the layers used by this output head.

    Judging by the class name, this is presumably the next-sentence-prediction
    head (TODO confirm against ``construct``, which is not visible here). It
    holds a dense projection from ``config.hidden_size`` to 2 outputs and a
    log-softmax operator.

    Args:
        config: Model configuration. The attributes read here are
            ``initializer_range``, ``hidden_size``, ``compute_type`` and
            ``dtype``.
    """
    super(GetNextSentenceOutput, self).__init__()
    self.log_softmax = _selected_ops.LogSoftmax()

    # The weight initializer is parameterised by config.initializer_range.
    initializer = TruncatedNormal(config.initializer_range)
    projection = nn.Dense(
        config.hidden_size,
        2,  # two output units
        weight_init=initializer,
        has_bias=True,
    )
    # Store whatever to_float returns for the configured compute type.
    self.dense = projection.to_float(config.compute_type)

    # NOTE(review): dtype and cast are presumably used in construct to cast
    # results back to config.dtype -- confirm.
    self.dtype = config.dtype
    self.cast = ops.Cast()
def __init__(self, axis=-1):
    """Build the log-softmax cell.

    Args:
        axis (int): Axis passed to ``_selected_ops.LogSoftmax``
            (presumably the last axis when negative one). Default: -1.
    """
    super(LogSoftmax, self).__init__()
    # The operator is presumably invoked from construct (not visible here).
    log_softmax_op = _selected_ops.LogSoftmax(axis)
    self.log_softmax = log_softmax_op