Exemplo n.º 1
0
 def __init__(self, config):
     """Create the sub-layers and ops used by this output head.

     Args:
         config: Configuration object. This method reads its
             ``initializer_range``, ``hidden_size``, ``compute_type`` and
             ``dtype`` attributes.
     """
     super(GetNextSentenceOutput, self).__init__()
     self.log_softmax = _selected_ops.LogSoftmax()
     # Weight initializer for the projection below.
     dense_weight_init = TruncatedNormal(config.initializer_range)
     # Project hidden_size features to 2 outputs with a bias term. The 2 outputs are
     # presumably the IsNext / NotNext classes implied by the class name (TODO confirm).
     projection = nn.Dense(config.hidden_size, 2,
                           weight_init=dense_weight_init,
                           has_bias=True)
     # to_float makes the layer compute in config.compute_type.
     self.dense = projection.to_float(config.compute_type)
     # NOTE(review): cast/dtype are not used in __init__. Presumably construct()
     # casts its result to config.dtype — confirm against the rest of the class.
     self.cast = ops.Cast()
     self.dtype = config.dtype
Exemplo n.º 2
0
 def __init__(self, axis=-1):
     """Hold a LogSoftmax op obtained from ``_selected_ops``.

     Args:
         axis: Forwarded positionally to ``_selected_ops.LogSoftmax``.
             Defaults to -1 (conventionally the last axis).
     """
     super(LogSoftmax, self).__init__()
     selected_op = _selected_ops.LogSoftmax(axis)
     self.log_softmax = selected_op