Пример #1
0
def linear(args,
           output_size,
           bias,
           bias_start=0.0,
           scope=None,
           squeeze=False,
           wd=0.0,
           input_keep_prob=1.0,
           is_train=None,
           kernel_initializer=None):
    """Apply `_linear` to a tensor (or sequence of tensors) of arbitrary rank.

    Each input is flattened with `flatten(arg, 1)` (e.g. [N*JX*JQ, d]),
    optionally passed through dropout, projected by `_linear` inside
    `tf.variable_scope(scope or 'linear')`, and the result is restored to the
    leading shape of `args[0]` with `reconstruct(..., 1)`.

    Args:
        args: a tensor, or a non-empty sequence of tensors.
        output_size: forwarded to `_linear`.
        bias: forwarded to `_linear`.
        bias_start: accepted for interface compatibility but NOT forwarded to
            `_linear`, so it currently has no effect.
        scope: variable scope name; defaults to 'linear'.
        squeeze: if true, squeeze axis `rank(args[0]) - 1` of the output
            (presumably its last axis, i.e. intended for output_size == 1).
        wd: if non-zero, `add_wd(wd)` is called after the layer is built.
        input_keep_prob: dropout keep probability applied to the flattened
            inputs; dropout ops are only built when this is < 1.0.
        is_train: boolean tensor; `tf.cond` selects dropout when true and the
            identity when false. Required when input_keep_prob < 1.0.
        kernel_initializer: forwarded to `_linear`.

    Returns:
        The reconstructed (and optionally squeezed) output tensor.

    Raises:
        ValueError: if `args` is None or an empty sequence, or if dropout is
            requested without `is_train`.
    """
    if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not nest.is_sequence(args):
        args = [args]

    flat_args = [flatten(arg, 1) for arg in args]  # flat_args[0] : [N*JX*JQ, d]
    if input_keep_prob < 1.0:
        # An explicit error instead of `assert`, which is stripped under -O.
        if is_train is None:
            raise ValueError("`is_train` must be given when input_keep_prob < 1.0")
        # Bind `arg` as a default so each branch closes over its own tensor
        # rather than the comprehension's loop variable. keep_prob is passed by
        # keyword so it can never be misread as a positional `rate`.
        flat_args = [
            tf.cond(is_train,
                    lambda arg=arg: tf.nn.dropout(arg, keep_prob=input_keep_prob),
                    lambda arg=arg: arg)
            for arg in flat_args
        ]
    with tf.variable_scope(scope or 'linear'):
        flat_out = _linear(flat_args,
                           output_size,
                           bias,
                           kernel_initializer=kernel_initializer)
    out = reconstruct(flat_out, args[0], 1)
    if squeeze:
        out = tf.squeeze(out, [len(args[0].get_shape().as_list()) - 1])
    if wd:
        # NOTE(review): add_wd runs outside the variable scope above; if it
        # selects variables by the current scope it may cover more than this
        # layer's weights — confirm against add_wd's implementation.
        add_wd(wd)
    return out
Пример #2
0
def l2_normalize(in_, proj=True, input_keep_prob=1.0, wd=0.0, is_train=None, scope=None):
    """L2-normalize `in_`, optionally after a learned projection.

    Args:
        in_: input tensor; its static last dimension (read via
            `get_shape().as_list()[-1]`) is passed to `F` as its second argument.
        proj: when true, `in_` is first passed through `F` under scope 'in_'
            with `use_bias=True`.
        input_keep_prob, wd, is_train: forwarded unchanged to `F`.
        scope: variable scope name; defaults to "l2_normalize".

    Returns:
        The result of `tf.nn.l2_normalize` on axis 1 of `flatten(x, 1)`,
        restored with `reconstruct(..., x, 1)`, where `x` is the (possibly
        projected) input.
    """
    last_dim = in_.get_shape().as_list()[-1]
    with tf.variable_scope(scope or "l2_normalize"):
        source = in_
        if proj:
            source = F(source, last_dim,
                       scope='in_',
                       input_keep_prob=input_keep_prob,
                       wd=wd,
                       is_train=is_train,
                       use_bias=True)
        # Normalize a flattened view, then restore the shape of `source`.
        normalized = tf.nn.l2_normalize(flatten(source, 1), 1)
        return reconstruct(normalized, source, 1)
Пример #3
0
def softmax(logits, mask=None, scope=None, rescale=False, dim=None):
    """Softmax of `logits` with optional masking and 1/sqrt(dim) scaling.

    The logits are flattened with `flatten(logits, 1)`, passed through
    `tf.nn.softmax`, and restored with `reconstruct(..., logits, 1)`.

    Args:
        logits: input tensor.
        mask: optional mask; when given, `exp_mask(logits, mask)` is applied
            before any scaling.
        scope: name scope; defaults to "Softmax".
        rescale: when true, divide the (masked) logits by sqrt(dim).
        dim: scaling dimension, as a Python number or tensor. Integers are
            accepted. Required when `rescale` is true.

    Returns:
        Softmax probabilities, reshaped by `reconstruct` to match `logits`.

    Raises:
        ValueError: if `rescale` is true and `dim` is None.
    """
    with tf.name_scope(scope or "Softmax"):
        if mask is not None:
            logits = exp_mask(logits, mask)
        if rescale:
            if dim is None:
                raise ValueError("`dim` must be provided when `rescale` is True")
            # tf.sqrt only accepts floating/complex dtypes, so an integer `dim`
            # (e.g. a hidden size) must be cast first. Casting to logits.dtype
            # also avoids a dtype clash for non-float32 logits. The scalar
            # divisor broadcasts, so no ones_like tensor is needed.
            logits = tf.divide(logits, tf.sqrt(tf.cast(dim, logits.dtype)))
        flat_logits = flatten(logits, 1)
        flat_out = tf.nn.softmax(flat_logits)
        return reconstruct(flat_out, logits, 1)