Example #1
def batch_dot(left, right):
    # wraps mxnet.symbol.batch_dot; `symbol` is mxnet.symbol, and `Function`
    # is the project-level wrapper that binds the input shapes and executes the symbol
    left_symbol = symbol.Variable('left')
    right_symbol = symbol.Variable('right')
    result_symbol = symbol.batch_dot(left_symbol, right_symbol)
    shapes = {'left': left.shape, 'right': right.shape}
    kwargs = {'left': left, 'right': right}
    return Function(result_symbol, shapes)(**kwargs)
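The wrapper above builds a two-input graph and evaluates it immediately. For reference, the underlying operator multiplies the trailing two dimensions of each batch element; below is a minimal imperative check of that shape contract using mxnet.nd.batch_dot (the NDArray counterpart of symbol.batch_dot), with illustrative array names and sizes that are not part of the original code:

import mxnet as mx

# (batch, n, k) x (batch, k, m) -> (batch, n, m); sizes are illustrative
left = mx.nd.random.uniform(shape=(4, 3, 5))
right = mx.nd.random.uniform(shape=(4, 5, 2))
out = mx.nd.batch_dot(left, right)
assert out.shape == (4, 3, 2)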
Example #2
File: facility.py  Project: zzlab/DL-Alg
def batch_dot(left, right):
    # assert left.shape[0] == right.shape[0] and left.shape[2] == right.shape[1]
    left_symbol = symbol.Variable('left')
    right_symbol = symbol.Variable('right')
    result_symbol = symbol.batch_dot(left_symbol, right_symbol)
    shapes = {'left': left.shape, 'right': right.shape}
    kwargs = {'left': left, 'right': right}
    return Function(result_symbol, shapes)(**kwargs)
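    # The method below belongs to the dot-attention cell class defined in
    # facility.py; the class header is not included in this snippet.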
    def __call__(self, inputs, states):
        # inputs: (batch_size, decoder_num_hidden)
        # for dot attention decoder_num_hidden must equal encoder_num_hidden
        if len(states) > 1:
            states = [symbol.concat(*states, dim=1)]

        # source: (batch_size, seq_len, encoder_num_hidden)
        source = states[0]
        # (batch_size, decoder_num_hidden, 1)
        inputs = symbol.expand_dims(inputs, axis=2)
        # (batch_size, seq_len, 1)
        scores = symbol.batch_dot(source, inputs)
        # (batch_size, encoder_num_hidden)
        return _attention_pooling(source, scores), states
def _attention_pooling(source, scores):
    # source: (batch_size, seq_len, encoder_num_hidden)
    # scores: (batch_size, seq_len, 1)
    probs = symbol.softmax(scores, axis=1)
    output = symbol.batch_dot(source, probs, transpose_a=True)
    return symbol.reshape(output, shape=(0, 0))
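Together, __call__ and _attention_pooling implement plain dot attention: batch_dot of the encoder states with the decoder state yields one score per source position, softmax over the sequence axis turns the scores into weights, and a second batch_dot (with transpose_a=True) forms the weighted sum of encoder states. A small imperative sketch of the same shape flow, using illustrative sizes that are not part of the original code:

import mxnet as mx

B, L, H = 2, 7, 16                                   # batch, seq_len, num_hidden
source = mx.nd.random.uniform(shape=(B, L, H))       # encoder states
inputs = mx.nd.random.uniform(shape=(B, H))          # decoder state

scores = mx.nd.batch_dot(source, mx.nd.expand_dims(inputs, axis=2))  # (B, L, 1)
probs = mx.nd.softmax(scores, axis=1)                                 # normalize over seq_len
pooled = mx.nd.batch_dot(source, probs, transpose_a=True)             # (B, H, 1)
pooled = mx.nd.reshape(pooled, shape=(0, 0))                          # (B, H)
assert pooled.shape == (B, H)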