from mxnet import symbol  # Function is assumed to be the symbolic-execution wrapper defined earlier


def batch_dot(left, right):
    # Wraps mxnet.symbol.batch_dot for concrete arrays:
    # left is (batch, M, K), right is (batch, K, N).
    assert left.shape[0] == right.shape[0] and left.shape[2] == right.shape[1]
    left_symbol = symbol.Variable('left')
    right_symbol = symbol.Variable('right')
    result_symbol = symbol.batch_dot(left_symbol, right_symbol)
    shapes = {'left': left.shape, 'right': right.shape}
    kwargs = {'left': left, 'right': right}
    return Function(result_symbol, shapes)(**kwargs)
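As a sanity check on the shape contract, the same contraction can be written in plain NumPy. This is a minimal sketch, not part of the wrapper; the sizes 4, 2, 3, and 5 are arbitrary.

import numpy as np

# left is (batch, M, K), right is (batch, K, N); batch_dot contracts over K.
left = np.random.rand(4, 2, 3)
right = np.random.rand(4, 3, 5)
result = np.einsum('bmk,bkn->bmn', left, right)  # same contraction batch_dot performs
assert result.shape == (4, 2, 5)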
def __call__(self, inputs, states):
    # inputs: (batch_size, decoder_num_hidden)
    # For dot attention, decoder_num_hidden must equal encoder_num_hidden.
    if len(states) > 1:
        states = [symbol.concat(*states, dim=1)]
    # source: (batch_size, seq_len, encoder_num_hidden)
    source = states[0]
    # (batch_size, decoder_num_hidden, 1)
    inputs = symbol.expand_dims(inputs, axis=2)
    # (batch_size, seq_len, 1)
    scores = symbol.batch_dot(source, inputs)
    # (batch_size, encoder_num_hidden)
    return _attention_pooling(source, scores), states
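To see what the expand_dims plus batch_dot pair computes, here is the equivalent score computation in NumPy. This is a sketch with made-up sizes (batch 2, seq_len 5, a shared hidden size of 8); einsum stands in for the symbolic ops.

import numpy as np

source = np.random.rand(2, 5, 8)   # encoder states
inputs = np.random.rand(2, 8)      # current decoder state
# Equivalent to batch_dot(source, expand_dims(inputs, axis=2)):
scores = np.einsum('bsh,bh->bs', source, inputs)[:, :, np.newaxis]
assert scores.shape == (2, 5, 1)   # one unnormalized score per source position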
def _attention_pooling(source, scores):
    # source: (batch_size, seq_len, encoder_num_hidden)
    # scores: (batch_size, seq_len, 1)
    probs = symbol.softmax(scores, axis=1)
    # (batch_size, encoder_num_hidden, 1)
    output = symbol.batch_dot(source, probs, transpose_a=True)
    # Drop the trailing singleton axis: (batch_size, encoder_num_hidden).
    return symbol.reshape(output, shape=(0, 0))
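The pooling step can be checked the same way. In this sketch the softmax helper and the sizes are mine, not part of the original code; einsum plus squeeze stands in for the transposed batch_dot and the reshape.

import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

source = np.random.rand(2, 5, 8)   # (batch_size, seq_len, encoder_num_hidden)
scores = np.random.rand(2, 5, 1)   # (batch_size, seq_len, 1)
probs = softmax(scores, axis=1)    # attention weights over seq_len
# Equivalent to batch_dot(source, probs, transpose_a=True) plus the reshape:
output = np.einsum('bsh,bso->bho', source, probs).squeeze(axis=2)
assert output.shape == (2, 8)      # (batch_size, encoder_num_hidden)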