Example #1
    def __call__(self, inputs, state, scope=None):
        with tf.variable_scope(scope,
                               default_name="gru_cell",
                               values=[inputs, state]):
            if not isinstance(inputs, (list, tuple)):
                inputs = [inputs]

            all_inputs = list(inputs) + [state]
            r = tf.nn.sigmoid(
                linear(all_inputs,
                       self._num_units,
                       False,
                       False,
                       scope="reset_gate"))
            u = tf.nn.sigmoid(
                linear(all_inputs,
                       self._num_units,
                       False,
                       False,
                       scope="update_gate"))
            all_inputs = list(inputs) + [r * state]
            c = linear(all_inputs,
                       self._num_units,
                       True,
                       False,
                       scope="candidate")

            new_state = (1.0 - u) * state + u * tf.tanh(c)

        return new_state, new_state
Example #2
def attention(query,
              memories,
              bias,
              hidden_size,
              cache=None,
              reuse=None,
              dtype=None,
              scope=None):
    """ Standard attention layer

    :param query: A tensor with shape [batch, key_size]
    :param memories: A tensor with shape [batch, memory_size, key_size]
    :param bias: A tensor with shape [batch, memory_size]
    :param hidden_size: An integer
    :param cache: A dictionary of precomputed values
    :param reuse: A boolean value, whether to reuse the scope
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string, the scope of this layer
    :return: A dict with key "value" (a tensor with shape [batch, key_size])
        and key "weight" (a tensor with shape [batch, memory_size])
    """

    with tf.variable_scope(scope or "attention",
                           reuse=reuse,
                           values=[query, memories, bias],
                           dtype=dtype):
        mem_shape = tf.shape(memories)
        key_size = memories.get_shape().as_list()[-1]

        if cache is None:
            k = tf.reshape(memories, [-1, key_size])
            k = linear(k, hidden_size, False, False, scope="k_transform")

            if query is None:
                return {"key": k}
        else:
            k = cache["key"]

        q = linear(query, hidden_size, False, False, scope="q_transform")
        k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size])

        hidden = tf.tanh(q[:, None, :] + k)
        hidden = tf.reshape(hidden, [-1, hidden_size])

        # Shape: [batch, mem_size, 1]
        logits = linear(hidden, 1, False, False, scope="logits")
        logits = tf.reshape(logits, [-1, mem_shape[1]])

        if bias is not None:
            logits = logits + bias

        alpha = tf.nn.softmax(logits)

        outputs = {
            "value": tf.reduce_sum(alpha[:, :, None] * memories, axis=1),
            "weight": alpha
        }

    return outputs
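
The cache branch above lets the key projection be computed once and reused at
every decoding step. A minimal usage sketch under that assumption (the tensor
names encoder_outputs, decoder_state and enc_attn_bias are illustrative, not
from the source):

# Precompute the transformed keys once; query=None returns {"key": ...}.
key_cache = attention(None, encoder_outputs, None, hidden_size,
                      scope="attention")

# Score one decoder state against the cached keys at each step.
result = attention(decoder_state, encoder_outputs, enc_attn_bias, hidden_size,
                   cache=key_cache, reuse=tf.AUTO_REUSE, scope="attention")
context = result["value"]    # [batch, key_size]
weights = result["weight"]   # [batch, memory_size]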
Example #3
def additive_attention(queries, keys, values, bias, hidden_size, concat=False,
                       keep_prob=None, dtype=None, scope=None):
    """ Additive attention mechanism. This layer is implemented using a
        one layer feed forward neural network

    :param queries: A tensor with shape [batch, heads, length_q, depth_k]
    :param keys: A tensor with shape [batch, heads, length_kv, depth_k]
    :param values: A tensor with shape [batch, heads, length_kv, depth_v]
    :param bias: A tensor
    :param hidden_size: An integer
    :param concat: A boolean value. If ``concat`` is True, the attention
        score is computed as $tanh(W[q, k])$; if False, it is computed as
        $tanh(Wq + Vk)$
    :param keep_prob: a scalar in [0, 1]
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string, the scope of this layer

    :returns: A dict with the following keys:
        weights: A tensor with shape [batch, heads, length_q, length_kv]
        outputs: A tensor with shape [batch, heads, length_q, depth_v]
    """

    with tf.variable_scope(scope, default_name="additive_attention",
                           values=[queries, keys, values, bias], dtype=dtype):
        length_q = tf.shape(queries)[2]
        length_kv = tf.shape(keys)[2]
        q = tf.tile(tf.expand_dims(queries, 3), [1, 1, 1, length_kv, 1])
        k = tf.tile(tf.expand_dims(keys, 2), [1, 1, length_q, 1, 1])

        if concat:
            combined = tf.tanh(linear(tf.concat([q, k], axis=-1), hidden_size,
                                      True, True, name="qk_transform"))
        else:
            # Use the tiled tensors so every query position pairs with every key
            q = linear(q, hidden_size, True, True, name="q_transform")
            k = linear(k, hidden_size, True, True, name="key_transform")
            combined = tf.tanh(q + k)

        # shape: [batch, heads, length_q, length_kv]
        logits = tf.squeeze(linear(combined, 1, True, True, name="logits"),
                            axis=-1)

        if bias is not None:
            logits += bias

        weights = tf.nn.softmax(logits, name="attention_weights")

        if keep_prob is not None and keep_prob < 1.0:
            weights = tf.nn.dropout(weights, keep_prob)

        outputs = tf.matmul(weights, values)

        return {"weights": weights, "outputs": outputs}
Example #4
def attention(query,
              memories,
              bias,
              hidden_size,
              cache=None,
              reuse=None,
              dtype=None,
              scope=None):

    with tf.variable_scope(scope or "attention",
                           reuse=reuse,
                           values=[query, memories, bias],
                           dtype=dtype):
        mem_shape = tf.shape(memories)
        key_size = memories.get_shape().as_list()[-1]

        if cache is None:
            k = tf.reshape(memories, [-1, key_size])
            k = linear(k, hidden_size, False, False, scope="k_transform")

            if query is None:
                return {"key": k}
        else:
            k = cache["key"]

        q = linear(query, hidden_size, False, False, scope="q_transform")
        k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size])

        hidden = tf.tanh(q[:, None, :] + k)
        hidden = tf.reshape(hidden, [-1, hidden_size])

        logits = linear(hidden, 1, False, False, scope="logits")
        logits = tf.reshape(logits, [-1, mem_shape[1]])

        if bias is not None:
            logits = logits + bias

        alpha = tf.nn.softmax(logits)

        outputs = {
            "value": tf.reduce_sum(alpha[:, :, None] * memories, axis=1),
            "weight": alpha
        }

    return outputs
Example #5
def additive_attention(queries,
                       keys,
                       values,
                       bias,
                       hidden_size,
                       concat=False,
                       keep_prob=None,
                       dtype=None,
                       scope=None):

    with tf.variable_scope(scope,
                           default_name="additive_attention",
                           values=[queries, keys, values, bias],
                           dtype=dtype):
        length_q = tf.shape(queries)[2]
        length_kv = tf.shape(keys)[2]
        q = tf.tile(tf.expand_dims(queries, 3), [1, 1, 1, length_kv, 1])
        k = tf.tile(tf.expand_dims(keys, 2), [1, 1, length_q, 1, 1])

        if concat:
            combined = tf.tanh(
                linear(tf.concat([q, k], axis=-1),
                       hidden_size,
                       True,
                       True,
                       name="qk_transform"))
        else:
            # Use the tiled tensors so every query position pairs with every key
            q = linear(q, hidden_size, True, True, name="q_transform")
            k = linear(k, hidden_size, True, True, name="key_transform")
            combined = tf.tanh(q + k)

        logits = tf.squeeze(linear(combined, 1, True, True, name="logits"),
                            axis=-1)

        if bias is not None:
            logits += bias

        weights = tf.nn.softmax(logits, name="attention_weights")

        if keep_prob is not None and keep_prob < 1.0:
            weights = tf.nn.dropout(weights, keep_prob)

        outputs = tf.matmul(weights, values)

        return {"weights": weights, "outputs": outputs}
Example #6
    def __call__(self, inputs, state, scope=None):

        with tf.variable_scope(scope, default_name="lstm_cell",
                               values=[inputs, state]):

            i = tf.nn.sigmoid(linear(inputs, self._num_units, False, False,
                                     scope="input_gate"))
            f = tf.nn.sigmoid(linear(inputs, self._num_units, False, False,
                                     scope="forget_gate"))
            o = tf.nn.sigmoid(linear(inputs, self._num_units, False, False,
                                     scope="outut_gate"))
            c_ = tf.nn.sigmoid(linear(inputs, self._num_units, False, False,
                                     scope="cell_gate"))

            new_state = f * state + i * c_

            new_inputs = o * tf.tanh(new_state)

        return new_inputs, new_state
Example #7
def policy_network(queries, memories, size, num_heads, mode,
                   soft_select=False, keep_prob=None, scope=None):
    with tf.variable_scope(scope, default_name="policy_network",
                           values=[queries, memories]):
        with tf.variable_scope("input_layer"):
            if memories is None:
                combined = linear(queries, 2 * size, True, True, scope="qk_transform")
                q, k = tf.split(combined, [size, size], axis=-1) 
            else:
                q = linear(queries, size, True, True, scope="q_transform")
                k = linear(memories, size, True, scope="k_transform")
            
            #q = tf.tanh(q)
            #k = tf.tanh(k)

            q = split_heads(q, num_heads)
            k = split_heads(k, num_heads)

            #q = linear(q, size, True, True, scope="q2_transform")
            #k = linear(k, size, True, True, scope="k2_transform")
            
            logits = tf.matmul(q, k, transpose_b=True)
            prob = tf.nn.sigmoid(logits, name="policy")

        with tf.variable_scope("output_layer"):
            if mode == 'train':
                if not soft_select:
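                    # Sample from a relaxed (Gumbel-Sigmoid) Bernoulli with
                    # temperature 0.5, then apply a straight-through trick:
                    # the forward pass uses the hard 0/1 threshold while the
                    # gradient flows through the relaxed sample.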
                    sample_prob = RelaxedBernoulli(0.5, logits=logits)
                    y = sample_prob.sample()
                    y_hard = tf.cast(tf.greater(y, 0.5), y.dtype)
                    y = tf.stop_gradient(y_hard - y) + y
                else:
                    y = prob
            else:
                #prob = tf.Print(prob, [tf.reduce_mean(prob), prob], summarize=100000)
                y = tf.cast(tf.greater(prob, 0.5), prob.dtype)
        nce = -tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=prob)
        return y, prob, nce
Example #8
def multihead_attention(queries,
                        memories,
                        bias,
                        num_heads,
                        key_size,
                        value_size,
                        output_size,
                        keep_prob=None,
                        output=True,
                        dtype=None,
                        scope=None):
    """ Multi-head scaled-dot-product attention with input/output
        transformations.

    :param queries: A tensor with shape [batch, length_q, depth_q]
    :param memories: A tensor with shape [batch, length_m, depth_m]
    :param bias: A tensor (see attention_bias)
    :param num_heads: An integer dividing key_size and value_size
    :param key_size: An integer
    :param value_size: An integer
    :param output_size: An integer
    :param keep_prob: A floating point number in (0, 1]
    :param output: Whether to use output transformation
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string

    :returns: A dict with the following keys:
        weights: A tensor with shape [batch, heads, length_q, length_kv]
        outputs: A tensor with shape [batch, length_q, depth_v]
    """

    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope,
                           default_name="multihead_attention",
                           values=[queries, memories],
                           dtype=dtype):
        if memories is None:
            # self attention
            size = key_size * 2 + value_size
            combined = linear(queries, size, True, True, scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=-1)
        else:
            q = linear(queries, key_size, True, True, scope="q_transform")
            combined = linear(memories,
                              key_size + value_size,
                              True,
                              scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=-1)

        # split heads
        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # scale query
        key_depth_per_head = key_size // num_heads
        q *= key_depth_per_head**-0.5

        # attention
        results = multiplicative_attention(q, k, v, bias, keep_prob)

        # combine heads
        weights = results["weights"]
        x = combine_heads(results["outputs"])

        if output:
            outputs = linear(x,
                             output_size,
                             True,
                             True,
                             scope="output_transform")
        else:
            outputs = x

        return {"weights": weights, "outputs": outputs}
Example #9
def attention_mhead(query,
                    memories,
                    bias,
                    hidden_size,
                    num_heads=16,
                    cache=None,
                    reuse=None,
                    dtype=None,
                    scope=None):
    """ Standard attention layer

    :param query: A tensor with shape [batch, key_size]
    :param memories: A tensor with shape [batch, memory_size, key_size]
    :param bias: A tensor with shape [batch, memory_size]
    :param hidden_size: An integer
    :param num_heads: An integer, the number of attention heads
    :param cache: A dictionary of precomputed values
    :param reuse: A boolean value, whether to reuse the scope
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string, the scope of this layer
    :return: A dict with key "value" (a tensor with shape
        [batch, 2 * hidden_size]) and key "weight" (a tensor with shape
        [batch, num_heads, memory_size])
    """

    with tf.variable_scope(scope or "attention_mhead",
                           reuse=reuse,
                           values=[query, memories, bias],
                           dtype=dtype):
        mem_shape = tf.shape(memories)
        key_size = memories.get_shape().as_list()[-1]

        if cache is None:
            k = tf.reshape(memories, [-1, key_size])
            k = linear(k, hidden_size, False, False, scope="k_transform")

            if query is None:
                return {"key": k}
        else:
            k = cache["key"]

        q = linear(query, hidden_size, False, False, scope="q_transform")
        k = tf.reshape(k, [mem_shape[0], mem_shape[1], hidden_size])

        # split heads
        q = split_heads(q[:, None, :], num_heads)
        k = split_heads(k, num_heads)

        # scale query
        #key_depth_per_head = hidden_size // num_heads
        #q *= key_depth_per_head ** -0.5

        hidden = tf.tanh(q + k)
        #hidden = tf.reshape(hidden, [-1, hidden_size])

        # Shape: [batch, num_heads, mem_size, 1]
        logits = linear(hidden, 1, False, False, scope="logits")
        logits = tf.reshape(logits, [mem_shape[0], num_heads, mem_shape[1]])

        if bias is not None:
            logits = logits + bias

        alpha = tf.nn.softmax(logits)

        memories_split = split_heads(memories, num_heads)
        x = combine_heads(alpha[:, :, :, None] * memories_split)

        y = linear(tf.reduce_sum(x, axis=1),
                   hidden_size * 2,
                   True,
                   True,
                   scope="output_transform")

        outputs = {"value": y, "weight": alpha}

    return outputs
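
Note that the logits above have shape [batch, num_heads, memory_size], so a
bias shaped [batch, memory_size] as in the docstring presumably needs an extra
head axis before the addition can broadcast. A minimal sketch of that
adjustment (this reshape is an assumption, not part of the source):

# bias: [batch, memory_size] -> [batch, 1, memory_size], broadcast over heads
bias = bias[:, None, :]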
Example #10
def multihead_attention(queries, memories, bias, num_heads, key_size,
                        value_size, output_size, keep_prob=None, output=True,
                        state=None, summary=True, dtype=None, scope=None, scaled_bias=None):
    """ Multi-head scaled-dot-product attention with input/output
        transformations.

    :param queries: A tensor with shape [batch, length_q, depth_q]
    :param memories: A tensor with shape [batch, length_m, depth_m]
    :param bias: A tensor (see attention_bias)
    :param num_heads: An integer dividing key_size and value_size
    :param key_size: An integer
    :param value_size: An integer
    :param output_size: An integer
    :param keep_prob: A floating point number in (0, 1]
    :param output: Whether to use output transformation
    :param state: An optional dictionary used for incremental decoding
    :param summary: Use image summary
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string

    :returns: A dict with the following keys:
        weights: A tensor with shape [batch, heads, length_q, length_kv]
        outputs: A tensor with shape [batch, length_q, depth_v]
    """

    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope, default_name="multihead_attention",
                           values=[queries, memories], dtype=dtype):
        next_state = {}

        if memories is None:
            # self attention
            size = key_size * 2 + value_size
            combined = linear(queries, size, True, True, scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=-1)

            if state is not None:
                k = tf.concat([state["key"], k], axis=1)
                v = tf.concat([state["value"], v], axis=1)
                next_state["key"] = k
                next_state["value"] = v
        else:
            q = linear(queries, key_size, True, True, scope="q_transform")
            combined = linear(memories, key_size + value_size, True,
                              scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=-1)

        # split heads
        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # scale query
        key_depth_per_head = key_size // num_heads
        q *= key_depth_per_head ** -0.5

        # attention
        results = multiplicative_attention(q, k, v, bias, keep_prob, scaled_bias=scaled_bias)

        # combine heads
        weights = results["weights"]
        x = combine_heads(results["outputs"])

        if output:
            outputs = linear(x, output_size, True, True,
                             scope="output_transform")
        else:
            outputs = x

        if should_generate_summaries() and summary:
            attention_image_summary(weights)

        outputs = {"weights": weights, "outputs": outputs}

        if state is not None:
            outputs["state"] = next_state

        return outputs
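
The state dictionary above supports incremental decoding: cached keys and
values from previous steps are concatenated with the current step's
projections. A minimal sketch of how a decoding loop might thread that state
through (batch_size, step_input and the other names are placeholders, not
from the source):

# Start with empty caches and feed one query position per step.
state = {"key": tf.zeros([batch_size, 0, key_size]),
         "value": tf.zeros([batch_size, 0, value_size])}

results = multihead_attention(step_input,      # [batch, 1, depth_q]
                              None,            # None triggers self-attention
                              bias, num_heads, key_size, value_size,
                              output_size, state=state)
outputs = results["outputs"]   # [batch, 1, output_size]
state = results["state"]       # caches grow by one position per step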
Example #11
def multihead_attention(queries,
                        memories,
                        bias,
                        num_heads,
                        key_size,
                        value_size,
                        output_size,
                        keep_prob=None,
                        output=True,
                        state=None,
                        summary=True,
                        dtype=None,
                        scope=None):

    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope,
                           default_name="multihead_attention",
                           values=[queries, memories],
                           dtype=dtype):
        next_state = {}

        if memories is None:
            size = key_size * 2 + value_size
            combined = linear(queries, size, True, True, scope="qkv_transform")
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=-1)

            if state is not None:
                k = tf.concat([state["key"], k], axis=1)
                v = tf.concat([state["value"], v], axis=1)
                next_state["key"] = k
                next_state["value"] = v
        else:
            q = linear(queries, key_size, True, True, scope="q_transform")
            combined = linear(memories,
                              key_size + value_size,
                              True,
                              scope="kv_transform")
            k, v = tf.split(combined, [key_size, value_size], axis=-1)

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        key_depth_per_head = key_size // num_heads
        q *= key_depth_per_head**-0.5

        results = multiplicative_attention(q, k, v, bias, keep_prob)

        weights = results["weights"]
        x = combine_heads(results["outputs"])

        if output:
            outputs = linear(x,
                             output_size,
                             True,
                             True,
                             scope="output_transform")
        else:
            outputs = x

        if should_generate_summaries() and summary:
            attention_image_summary(weights)

        outputs = {"weights": weights, "outputs": outputs}

        if state is not None:
            outputs["state"] = next_state

        return outputs
Example #12
def multihead_attention_v2n(queries,
                            memories,
                            bias,
                            w_x_inp,
                            num_heads,
                            key_size,
                            value_size,
                            output_size,
                            params,
                            keep_prob=None,
                            output=True,
                            dtype=None,
                            scope=None):
    """ Multi-head scaled-dot-product attention with input/output
        transformations.

    :param queries: A tensor with shape [batch, length_q, depth_q]
    :param memories: A tensor with shape [batch, length_m, depth_m]
    :param bias: A tensor (see attention_bias)
    :param w_x_inp: A tensor with shape [batch, len_src, len_src, d] or
        [batch, len_trg, len_trg, d]
    :param num_heads: An integer dividing key_size and value_size
    :param key_size: An integer
    :param value_size: An integer
    :param output_size: An integer
    :param keep_prob: A floating point number in (0, 1]
    :param output: Whether to use output transformation
    :param dtype: An optional instance of tf.DType
    :param scope: An optional string


    :returns: A dict with the following keys:
        weights: A tensor with shape [batch, heads, length_q, length_v]
        outputs: A tensor with shape [batch, length_q, depth_v]
        weight_ratio: A tensor with shape [batch, length_q, d, length_v, d]
    """

    if key_size % num_heads != 0:
        raise ValueError("Key size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (key_size, num_heads))

    if value_size % num_heads != 0:
        raise ValueError("Value size (%d) must be divisible by the number of "
                         "attention heads (%d)." % (value_size, num_heads))

    with tf.variable_scope(scope,
                           default_name="multihead_attention",
                           values=[queries, memories],
                           dtype=dtype):
        bs = tf.shape(w_x_inp)[0]
        len_q = tf.shape(queries)[1]
        len_src = tf.shape(w_x_inp)[1]
        dim = tf.shape(w_x_inp)[3]
        if memories is None:
            # self attention
            size = key_size * 2 + value_size
            combined_linear = linear_v2n(queries,
                                         size,
                                         True, [w_x_inp],
                                         params,
                                         True,
                                         scope="qkv_transform")
            combined = combined_linear["output"]
            q, k, v = tf.split(combined, [key_size, key_size, value_size],
                               axis=-1)
            w_x_combined = combined_linear["weight_ratios"][0]
            w_x_q, w_x_k, w_x_v = tf.split(w_x_combined,
                                           [key_size, key_size, value_size],
                                           axis=-1)
        else:
            q = nn.linear(queries,
                          key_size,
                          True,
                          params,
                          True,
                          scope="q_transform")
            combined_linear = linear_v2n(memories,
                                         key_size + value_size,
                                         True, [w_x_inp],
                                         params,
                                         True,
                                         scope="kv_transform")
            combined = combined_linear["output"]
            k, v = tf.split(combined, [key_size, value_size], axis=-1)
            w_x_combined = combined_linear["weight_ratios"][0]
            w_x_k, w_x_v = tf.split(w_x_combined, [key_size, value_size],
                                    axis=-1)

        # split heads
        q = attention.split_heads(q, num_heads)
        k = attention.split_heads(k, num_heads)
        v = attention.split_heads(v, num_heads)
        w_x_v = split_heads(w_x_v, num_heads)

        # scale query
        key_depth_per_head = key_size // num_heads
        q *= key_depth_per_head**-0.5

        # attention
        results = attention.multiplicative_attention(q, k, v, bias, keep_prob)

        # combine heads
        weights = results["weights"]
        x = attention.combine_heads(results["outputs"])

        w_x_v = tf.transpose(w_x_v, [0, 1, 3, 2, 4])
        w_x_v = tf.reshape(w_x_v, [bs, num_heads, tf.shape(w_x_v)[2], -1])
        w_x_att = tf.matmul(weights, w_x_v)
        w_x_att = tf.reshape(
            w_x_att, [bs, num_heads, len_q, len_src, key_depth_per_head])
        w_x_att = tf.transpose(w_x_att, [0, 1, 3, 2, 4])
        w_x_att = combine_heads_v2n(w_x_att)

        if output:
            outputs_linear = linear_v2n(x,
                                        output_size,
                                        True, [w_x_att],
                                        params,
                                        True,
                                        scope="output_transform")
            outputs = outputs_linear["output"]
            w_x_out = outputs_linear["weight_ratios"][0]
        else:
            outputs = x
            w_x_out = w_x_att

        return {
            "weights": weights,
            "outputs": outputs,
            "weight_ratio": w_x_out
        }
Example #13
def transformer_decoder(inputs, memory, bias, mem_bias, mt_bias, params, state=None,
                        dtype=None, scope=None, scaled_bias=None):
    with tf.variable_scope(scope, default_name="decoder", dtype=dtype,
                           values=[inputs, memory, bias, mem_bias, mt_bias]):
        x = inputs
        next_state = {}
        for layer in range(params.num_decoder_layers):
            layer_name = "layer_%d" % layer
            with tf.variable_scope(layer_name):
                layer_state = state[layer_name] if state is not None else None

                with tf.variable_scope("self_attention"):
                    y = layers.attention.multihead_attention(
                        _layer_process(x, params.layer_preprocess),
                        None,
                        bias,
                        params.num_heads,
                        params.attention_key_channels or params.hidden_size,
                        params.attention_value_channels or params.hidden_size,
                        params.hidden_size,
                        1.0 - params.attention_dropout,
                        state=layer_state
                    )

                    if layer_state is not None:
                        next_state[layer_name] = y["state"]

                    y = y["outputs"]
                    x = _residual_fn(x, y, 1.0 - params.residual_dropout)
                    x = _layer_process(x, params.layer_postprocess)

                with tf.variable_scope("encdec_attention"):
                    y = layers.attention.multihead_attention(
                        _layer_process(x, params.layer_preprocess),
                        memory,
                        mem_bias,
                        params.num_heads,
                        params.attention_key_channels or params.hidden_size,
                        params.attention_value_channels or params.hidden_size,
                        params.hidden_size,
                        1.0 - params.attention_dropout,
                        scaled_bias=scaled_bias,
                    )
                    y = y["outputs"]
                    x = _residual_fn(x, y, 1.0 - params.residual_dropout)
                    x = _layer_process(x, params.layer_postprocess)

                with tf.variable_scope("feed_forward"):
                    y = _ffn_layer(
                        _layer_process(x, params.layer_preprocess),
                        params.filter_size,
                        params.hidden_size,
                        1.0 - params.relu_dropout,
                    )
                    x = _residual_fn(x, y, 1.0 - params.residual_dropout)
                    x = _layer_process(x, params.layer_postprocess)

        with tf.variable_scope("copy_net"):
            z = x
            with tf.variable_scope("encdec_attention"):
                y = layers.attention.multihead_attention(
                    _layer_process(z, params.layer_preprocess),
                    memory,
                    mt_bias,
                    1, #params.num_heads,
                    params.attention_key_channels or params.hidden_size,
                    params.attention_value_channels or params.hidden_size,
                    params.hidden_size,
                    1.0 - params.attention_dropout,
                    scaled_bias=scaled_bias,
                )
                att = y["weights"]  # [bs, 1, lq, lk]
                y = y["outputs"]
                z = _residual_fn(z, y, 1.0 - params.residual_dropout)
                z = _layer_process(z, params.layer_postprocess)

            with tf.variable_scope("feed_forward"):
                y = _ffn_layer(
                    _layer_process(z, params.layer_preprocess),
                    params.filter_size,
                    params.hidden_size,
                    1.0 - params.relu_dropout,
                )
                z = _residual_fn(z, y, 1.0 - params.residual_dropout)
                z = _layer_process(z, params.layer_postprocess)

        outputs = _layer_process(x, params.layer_preprocess)

        z = _layer_process(z, params.layer_preprocess)
        z = linear(z, 1, True, True, scope="copy_ratio_w")
        z = tf.sigmoid(z)  # [bs, lq, 1]

        att = tf.squeeze(att, axis=1) # [bs, lq, lk]

        if state is not None:
            return outputs, next_state, att, z

        return outputs, att, z
Example #14
def main(args):
    eval_steps = args.eval_steps
    tf.logging.set_verbosity(tf.logging.DEBUG)
    # Load configs
    model_cls_list = [models.get_model(model) for model in args.models]
    params_list = [default_parameters() for _ in range(len(model_cls_list))]
    params_list = [
        merge_parameters(params, model_cls.get_parameters())
        for params, model_cls in zip(params_list, model_cls_list)
    ]
    params_list = [
        import_params(args.checkpoints[i], args.models[i], params_list[i])
        for i in range(len(args.checkpoints))
    ]
    params_list = [
        override_parameters(params_list[i], args)
        for i in range(len(model_cls_list))
    ]

    # Build Graph
    with tf.Graph().as_default():
        model_var_lists = []

        # Load checkpoints
        for i, checkpoint in enumerate(args.checkpoints):
            tf.logging.info("Loading %s" % checkpoint)
            var_list = tf.train.list_variables(checkpoint)
            values = {}
            reader = tf.train.load_checkpoint(checkpoint)

            for (name, shape) in var_list:
                if not name.startswith(model_cls_list[i].get_name()):
                    continue

                if name.find("losses_avg") >= 0:
                    continue

                tensor = reader.get_tensor(name)
                values[name] = tensor

            model_var_lists.append(values)

        # Build models
        model_fns = []

        for i in range(len(args.checkpoints)):
            name = model_cls_list[i].get_name()
            model = model_cls_list[i](params_list[i], name + "_%d" % i)
            model_fn = model.get_inference_func()
            model_fns.append(model_fn)

        params = params_list[0]
        # Read input file
        #features = dataset.get_inference_input(args.input, params)
        #features_eval = dataset.get_inference_input(args.eval, params)
        #features_test = dataset.get_inference_input(args.test, params)

        features_train = dataset.get_inference_input(args.input, params, False,
                                                     True)
        features_eval = dataset.get_inference_input(args.eval, params, True,
                                                    False)
        features_test = dataset.get_inference_input(args.test, params, True,
                                                    False)

        # Create placeholders
        placeholders = []

        for i in range(len(params.device_list)):
            placeholders.append({
                "source":
                tf.placeholder(tf.int32, [None, None], "source_%d" % i),
                "source_length":
                tf.placeholder(tf.int32, [None], "source_length_%d" % i),
                "target":
                tf.placeholder(tf.int32, [None, 2], "target_%d" % i)
            })

        # A list of outputs
        predictions = parallel.data_parallelism(
            params.device_list,
            lambda f: inference.create_inference_graph(model_fns, f, params),
            placeholders)

        # Create assign ops
        assign_ops = []

        all_var_list = tf.trainable_variables()

        for i in range(len(args.checkpoints)):
            un_init_var_list = []
            name = model_cls_list[i].get_name()

            for v in all_var_list:
                if v.name.startswith(name + "_%d" % i):
                    un_init_var_list.append(v)

            ops = set_variables(un_init_var_list, model_var_lists[i],
                                name + "_%d" % i)
            assign_ops.extend(ops)

        assign_op = tf.group(*assign_ops)
        results = []

        tf_x = tf.placeholder(tf.float32, [None, None, 512])
        tf_y = tf.placeholder(tf.int32, [None, 2])
        tf_x_len = tf.placeholder(tf.int32, [None])

        src_mask = -1e9 * (1.0 - tf.sequence_mask(
            tf_x_len, maxlen=tf.shape(predictions[0])[1], dtype=tf.float32))
        with tf.variable_scope("my_metric"):
            #q,k,v = tf.split(linear(tf_x, 3*512, True, True, scope="logit_transform"), [512, 512,512],axis=-1)
            q, k, v = tf.split(nn.linear(predictions[0],
                                         3 * 512,
                                         True,
                                         True,
                                         scope="logit_transform"),
                               [512, 512, 512],
                               axis=-1)
            q = nn.linear(
                tf.nn.tanh(q), 1, True, True,
                scope="logit_transform2")[:, :, 0] + src_mask
            # label smoothing
            ce1 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=q,
                labels=tf_y[:, :1],
                #smoothing=params.label_smoothing,
                smoothing=False,
                normalize=True)
            w1 = tf.nn.softmax(q)[:, None, :]
            #k = nn.linear(tf.nn.tanh(tf.matmul(w1,v)+k),1,True,True,scope="logit_transform3")[:,:,0]+src_mask
            k = tf.matmul(k,
                          tf.matmul(w1, v) *
                          (512**-0.5), False, True)[:, :, 0] + src_mask
            # label smoothing
            ce2 = nn.smoothed_softmax_cross_entropy_with_logits(
                logits=k,
                labels=tf_y[:, 1:],
                #smoothing=params.label_smoothing,
                smoothing=False,
                normalize=True)
            w2 = tf.nn.softmax(k)[:, None, :]
            weights = tf.concat([w1, w2], axis=1)
        loss = tf.reduce_mean(ce1 + ce2)

        #tf_x = tf.placeholder(tf.float32, [None, 512])
        #tf_y = tf.placeholder(tf.int32, [None])

        #l1 = tf.layers.dense(tf.squeeze(predictions[0], axis=-2), 64, tf.nn.sigmoid)
        #output = tf.layers.dense(l1, int(args.softmax_size))

        #loss = tf.losses.sparse_softmax_cross_entropy(labels=tf_y, logits=output)
        o1 = tf.argmax(w1, axis=-1)
        o2 = tf.argmax(w2, axis=-1)
        a1, a1_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 0]),
                                            predictions=tf.argmax(w1, axis=-1),
                                            name='a1')
        a2, a2_update = tf.metrics.accuracy(labels=tf.squeeze(tf_y[:, 1]),
                                            predictions=tf.argmax(w2, axis=-1),
                                            name='a2')
        accuracy, accuracy_update = tf.metrics.accuracy(
            labels=tf.squeeze(tf_y),
            predictions=tf.argmax(weights, axis=-1),
            name='a_all')

        running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="my_metric")
        #running_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="my_metric")
        running_vars_initializer = tf.variables_initializer(
            var_list=running_vars)

        #variables_to_train = tf.trainable_variables()
        #print (len(variables_to_train), (variables_to_train[0]), variables_to_train[1])
        #variables_to_train.remove(variables_to_train[0])
        #variables_to_train.remove(variables_to_train[0])
        #print (len(variables_to_train))
        variables_to_train = [
            v for v in tf.trainable_variables()
            if v.name.startswith("my_metric")
        ]

        optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(loss, var_list=variables_to_train)
        #train_op = optimizer.minimize(loss, var_list=running_vars)

        # Create session
        with tf.Session(config=session_config(params)) as sess:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            # Restore variables
            sess.run(assign_op)
            sess.run(tf.tables_initializer())

            current_step = 0

            best_validate_acc = 0
            last_test_acc = 0

            train_x_set = []
            train_y_set = []
            valid_x_set = []
            valid_y_set = []
            test_x_set = []
            test_y_set = []
            train_x_len_set = []
            valid_x_len_set = []
            test_x_len_set = []

            while current_step < eval_steps:
                print('=======current step ' + str(current_step))
                batch_num = 0
                while True:
                    try:
                        feats = sess.run(features_train)
                        op, feed_dict = shard_features(feats, placeholders,
                                                       predictions)
                        #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
                        y = feed_dict.values()[2]
                        x_len = feed_dict.values()[1]

                        feed_dict.update({tf_y: y})
                        feed_dict.update({tf_x_len: x_len})

                        los, __, pred = sess.run([loss, train_op, weights],
                                                 feed_dict=feed_dict)
                        print("current_step", current_step, "batch_num",
                              batch_num, "loss", los)

                        batch_num += 1
                        if batch_num % 100 == 0:

                            # eval
                            b_total = 0
                            a_total = 0
                            a1_total = 0
                            a2_total = 0
                            validate_acc = 0
                            batch_num_eval = 0

                            while True:
                                try:
                                    feats_eval = sess.run(features_eval)
                                    op, feed_dict_eval = shard_features(
                                        feats_eval, placeholders, predictions)
                                    #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
                                    y = feed_dict_eval.values()[2]
                                    x_len = feed_dict_eval.values()[1]
                                    feed_dict_eval.update({tf_y: y})
                                    feed_dict_eval.update({tf_x_len: x_len})

                                    sess.run(running_vars_initializer)
                                    acc = 0
                                    #acc, pred = sess.run([accuracy, output], feed_dict = {tf_x : x, tf_y : y})
                                    sess.run([
                                        a1_update, a2_update, accuracy_update,
                                        weights
                                    ],
                                             feed_dict=feed_dict_eval)
                                    acc1, acc2, acc = sess.run(
                                        [a1, a2, accuracy])
                                    batch_size = len(y)
                                    #print(acc)
                                    a1_total += round(batch_size * acc1)
                                    a2_total += round(batch_size * acc2)
                                    a_total += round(batch_size * acc)
                                    b_total += batch_size
                                    batch_num_eval += 1

                                    if batch_num_eval == 20:
                                        break

                                except tf.errors.OutOfRangeError:
                                    print("eval out of range")
                                    break
                            if b_total:
                                validate_acc = a_total / b_total
                                print("eval acc : " + str(validate_acc) +
                                      "( " + str(a1_total / b_total) + ", " +
                                      str(a2_total / b_total) + " )")
                            print("last test acc : " + str(last_test_acc))

                            if validate_acc > best_validate_acc:
                                best_validate_acc = validate_acc

                            # test
                            b_total = 0
                            a1_total = 0
                            a2_total = 0
                            a_total = 0
                            batch_num_test = 0
                            with open(args.output, "w") as outfile:
                                while True:
                                    try:
                                        feats_test = sess.run(features_test)
                                        op, feed_dict_test = shard_features(
                                            feats_test, placeholders,
                                            predictions)

                                        #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
                                        y = feed_dict_test.values()[2]
                                        x_len = feed_dict_test.values()[1]
                                        feed_dict_test.update({tf_y: y})
                                        feed_dict_test.update(
                                            {tf_x_len: x_len})

                                        sess.run(running_vars_initializer)
                                        acc = 0
                                        #acc, pred = sess.run([accuracy, output], feed_dict = {tf_x : x, tf_y : y})
                                        __, __, __, out1, out2 = sess.run(
                                            [
                                                a1_update, a2_update,
                                                accuracy_update, o1, o2
                                            ],
                                            feed_dict=feed_dict_test)
                                        acc1, acc2, acc = sess.run(
                                            [a1, a2, accuracy])

                                        batch_size = len(y)
                                        a_total += round(batch_size * acc)
                                        a1_total += round(batch_size * acc1)
                                        a2_total += round(batch_size * acc2)
                                        b_total += batch_size
                                        batch_num_test += 1
                                        for pred1, pred2 in zip(out1, out2):
                                            outfile.write("%s " % pred1[0])
                                            outfile.write("%s\n" % pred2[0])
                                        if batch_num_test == 20:
                                            break
                                    except tf.errors.OutOfRangeError:
                                        print("test out of range")
                                        break
                                if b_total:
                                    last_test_acc = a_total / b_total
                                    print("new test acc : " +
                                          str(last_test_acc) + "( " +
                                          str(a1_total / b_total) + ", " +
                                          str(a2_total / b_total) + " )")

                        if batch_num == 25000:
                            break
                    except tf.errors.OutOfRangeError:
                        print("train out of range")
                        break

                # eval


#                b_total = 0
#                a_total = 0
#                a1_total = 0
#                a2_total = 0
#                validate_acc = 0
#                batch_num = 0

#                while True:
#                    try:
#                        feats_eval = sess.run(features_eval)
#                        op, feed_dict = shard_features(feats_eval, placeholders, predictions)
#                        #x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
#                        y =  feed_dict.values()[2]
#                        x_len =  feed_dict.values()[1]
#                        feed_dict.update({tf_y:y})
#                        feed_dict.update({tf_x_len:x_len})

#                        sess.run(running_vars_initializer)
#                        acc = 0
#acc, pred = sess.run([accuracy, output], feed_dict = {tf_x : x, tf_y : y})
#                        sess.run([a1_update, a2_update, accuracy_update, weights], feed_dict = feed_dict)
#                        acc1,acc2,acc = sess.run([a1,a2,accuracy])
#                        batch_size = len(y)
#print(acc)
#                        a1_total += round(batch_size*acc1)
#                        a2_total += round(batch_size*acc2)
#                        a_total += round(batch_size*acc)
#                        b_total += batch_size
#                        batch_num += 1

#                        if batch_num == 10:
#                            break

#                    except tf.errors.OutOfRangeError:
#                        print ("eval out of range")
#                        break

#                validate_acc = a_total/b_total
#                print("eval acc : "  + str(validate_acc) + "( "+str(a1_total/b_total)+ ", "+ str(a2_total/b_total) + " )")
#                print("last test acc : " + str(last_test_acc))

#                if validate_acc > best_validate_acc:
#                    best_validate_acc = validate_acc

# test
#                    b_total = 0
#                    a1_total = 0
#                    a2_total = 0
#                    a_total = 0
#                    batch_num = 0

#                    while True:
#                        try:
#                            feats_test = sess.run(features_test)
#                            op, feed_dict = shard_features(feats_test, placeholders,
#                                                             predictions)

#x = (np.squeeze(sess.run(predictions, feed_dict=feed_dict), axis = -2))
#                            y =  feed_dict.values()[2]
#                            x_len =  feed_dict.values()[1]
#                            feed_dict.update({tf_y:y})
#                            feed_dict.update({tf_x_len:x_len})

#                            sess.run(running_vars_initializer)
#                            acc = 0
#acc, pred = sess.run([accuracy, output], feed_dict = {tf_x : x, tf_y : y})
#                            sess.run([a1_update,a2_update,accuracy_update, weights], feed_dict = feed_dict)
#                            acc1,acc2,acc = sess.run([a1,a2,accuracy])

#                            batch_size = len(y)
#                            a_total += round(batch_size*acc)
#                            a1_total += round(batch_size*acc1)
#                            a2_total += round(batch_size*acc2)
#                            b_total += batch_size
#                            batch_num += 1

#                            if batch_num==10:
#                                break
#                        except tf.errors.OutOfRangeError:
#                            print ("test out of range")
#                            break
#                    last_test_acc = a_total/b_total
#                    print("new test acc : " + str(last_test_acc)+ "( "+str(a1_total/b_total)+ ", "+ str(a2_total/b_total) + " )")

                current_step += 1
                print("")
        print("Final test acc " + str(last_test_acc))

        return