def _attention(self, x, h_prev_summary, c_tape, h_tape):
  """
  :param x: batch_size * input_size
  :param h_prev_summary: batch_size * cell_size
  :param c_tape: batch_size * memory_size * cell_size
  :param h_tape: batch_size * memory_size * cell_size
  :return: attention weights a, and the attention-weighted sums of c_tape and h_tape
  """
  input_size = x.get_shape().with_rank(2)[1]
  with vs.variable_scope("Attention"):
    # mask out empty slots
    mask = tf.sign(tf.reduce_max(tf.abs(h_tape), reduction_indices=2))

    # construct query for attention
    concat_w = rnn_cell._get_concat_variable(
        "query_w", [input_size.value + self._num_units, self._num_units],
        x.dtype, 1)
    b = vs.get_variable("query_bias", shape=[self._num_units],
                        initializer=array_ops.zeros_initializer, dtype=x.dtype)
    query = tf.nn.bias_add(
        math_ops.matmul(array_ops.concat(1, [x, h_prev_summary]), concat_w), b)
    query = array_ops.reshape(query, [-1, 1, 1, self._num_units])

    # get the weights for attention
    k = vs.get_variable("AttnW", [1, 1, self._num_units, self._num_units])
    v = vs.get_variable("AttnV", [self._num_units])
    hidden = array_ops.reshape(h_tape, [-1, self._attn_length, 1, self._num_units])
    memory = array_ops.reshape(c_tape, [-1, self._attn_length, 1, self._num_units])
    hidden_features = tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")

    s = tf.reduce_sum(v * tf.tanh(hidden_features + query), [2, 3])
    a = tf.nn.softmax(s) * mask
    a = a / (tf.reduce_sum(a, reduction_indices=1, keep_dims=True) + 1e-12)

    h_summary = tf.reduce_sum(
        array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
    c_summary = tf.reduce_sum(
        array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * memory, [1, 2])
    return a, c_summary, h_summary
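# A minimal standalone sketch (NumPy, not part of the cell) of the masked
# attention step used in _attention above: empty tape slots get zero weight
# via the sign/abs mask, and the surviving softmax weights are renormalized
# before the weighted sums over h_tape and c_tape are taken.  The function
# name and shapes are illustrative assumptions, not this module's API.
import numpy as np

def masked_attention_summary(scores, h_tape, c_tape):
  # scores: [batch, memory], h_tape / c_tape: [batch, memory, cell]
  mask = np.sign(np.max(np.abs(h_tape), axis=2))            # 0 for empty slots
  a = np.exp(scores - scores.max(axis=1, keepdims=True))    # stable softmax
  a = a / a.sum(axis=1, keepdims=True)
  a = a * mask
  a = a / (a.sum(axis=1, keepdims=True) + 1e-12)            # renormalize over non-empty slots
  h_summary = (a[:, :, None] * h_tape).sum(axis=1)          # [batch, cell]
  c_summary = (a[:, :, None] * c_tape).sum(axis=1)          # [batch, cell]
  return a, c_summary, h_summary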
def _attention(self, x, h_prev_summary, h_tape):
  """
  :param x: batch_size * input_size
  :param h_prev_summary: batch_size * cell_size
  :param h_tape: batch_size * memory_size * cell_size
  :return: h_summary, the attention-weighted sum of h_tape
  """
  input_size = x.get_shape().with_rank(2)[1]
  with vs.variable_scope("Attention"):
    # construct query for attention
    concat_w = rnn_cell._get_concat_variable(
        "query_w", [input_size.value + self._num_units, self._attn_size],
        x.dtype, 1)
    b = vs.get_variable("query_bias", shape=[self._attn_size],
                        initializer=array_ops.zeros_initializer, dtype=x.dtype)
    query = tf.nn.bias_add(
        math_ops.matmul(array_ops.concat(1, [x, h_prev_summary]), concat_w), b)
    query = array_ops.reshape(query, [-1, 1, 1, self._attn_size])

    # get temporal feature and mask for empty slots
    mask = tf.sign(tf.reduce_max(tf.abs(h_tape), reduction_indices=2, keep_dims=True))
    t_ids = tf.ones_like(mask) * self._attn_indexes
    mask = tf.squeeze(mask)

    # get the weights for attention
    k = vs.get_variable("AttnW", [1, 1, self._num_units + 1, self._attn_size])
    v = vs.get_variable("AttnV", [self._attn_size])
    hidden = tf.concat(2, [h_tape, t_ids])
    hidden = array_ops.reshape(hidden, [-1, self._attn_length, 1, self._num_units + 1])
    hidden_features = tf.nn.conv2d(hidden, k, [1, 1, 1, 1], "SAME")

    # compute attention point
    s = tf.reduce_sum(v * tf.tanh(hidden_features + query), [2, 3])
    a = tf.nn.softmax(s) * mask
    a = a / (tf.reduce_sum(a, reduction_indices=1, keep_dims=True) + 1e-12)

    h_summary = tf.reduce_sum(
        array_ops.reshape(a, [-1, self._attn_length, 1]) * h_tape,
        reduction_indices=1)
    return h_summary
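# Illustrative NumPy sketch of the temporal feature built above: every tape
# slot receives one extra channel carrying its position index before the
# attention scores are computed.  It assumes self._attn_indexes holds one
# index per memory slot; the helper name and shapes are hypothetical.
import numpy as np

def add_time_index_channel(h_tape, attn_indexes):
  # h_tape: [batch, memory, cell] -> [batch, memory, cell + 1]
  batch, memory, _ = h_tape.shape
  t_ids = np.ones((batch, memory, 1)) * np.reshape(attn_indexes, (1, memory, 1))
  return np.concatenate([h_tape, t_ids], axis=2)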
def __call__(self, inputs, state, scope=None):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D, batch x num_units.
    state: state Tensor, 2D, batch x state_size.
    scope: VariableScope for the created subgraph; defaults to "LSTMCell".

  Returns:
    A tuple containing:
    - A 2D, batch x output_dim, Tensor representing the output of the LSTM
      after reading "inputs" when previous state was "state".  Here
      output_dim is: num_proj if num_proj was set, num_units otherwise.
    - A 2D, batch x state_size, Tensor representing the new state of LSTM
      after reading "inputs" when previous state was "state".

  Raises:
    ValueError: if an input_size was specified and the provided inputs have
      a different dimension.
  """
  num_proj = self._num_units if self._num_proj is None else self._num_proj

  c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
  m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

  dtype = inputs.dtype
  actual_input_size = inputs.get_shape().as_list()[1]
  if self._input_size and self._input_size != actual_input_size:
    raise ValueError("Actual input size not same as specified: %d vs %d." %
                     (actual_input_size, self._input_size))

  scope_name = scope or type(self).__name__
  with vs.variable_scope(scope_name,
                         initializer=self._initializer):  # "LSTMCell"
    if not self._bn:
      concat_w = _get_concat_variable(
          "W", [actual_input_size + num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
    else:
      concat_w_i = _get_concat_variable(
          "W_i", [actual_input_size, 4 * self._num_units],
          dtype, self._num_unit_shards)
      concat_w_r = _get_concat_variable(
          "W_r", [num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)

    b = vs.get_variable("B", shape=[4 * self._num_units],
                        initializer=array_ops.zeros_initializer, dtype=dtype)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    if not self._bn:
      cell_inputs = array_ops.concat(1, [inputs, m_prev])
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
    else:
      lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                 self._deterministic, shift=False,
                                 scope=scope_name + 'bn_i')
      if self._bn > 1:
        lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r),
                                   self._deterministic, shift=False,
                                   scope=scope_name + 'bn_r')
      else:
        lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
      lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

    i, j, f, o = array_ops.split(1, 4, lstm_matrix)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable("W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable("W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable("W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
           sigmoid(i + w_i_diag * c_prev) * tanh(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

    if self._cell_clip is not None:
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)

    if self._use_peepholes:
      if self._bn > 2:
        m = sigmoid(o + w_o_diag * c) * tanh(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
    else:
      if self._bn > 2:
        m = sigmoid(o) * tanh(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o) * tanh(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj],
          dtype, self._num_proj_shards)
      m = math_ops.matmul(m, concat_w_proj)

  if not self._return_gate:
    return m, array_ops.concat(1, [c, m])
  else:
    return m, array_ops.concat(1, [c, m]), (i, j, f, o)
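# Standalone NumPy sketch of the gate arithmetic performed above, with batch
# norm and peepholes stripped out, to make the i/j/f/o split and the c/m
# update explicit.  The function name and argument layout are illustrative
# assumptions; the gate order matches the array_ops.split call in the cell.
import numpy as np

def lstm_step(x, m_prev, c_prev, W, b, forget_bias=1.0):
  # x: [batch, input], m_prev / c_prev: [batch, units], W: [input + units, 4 * units]
  sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
  lstm_matrix = np.concatenate([x, m_prev], axis=1).dot(W) + b
  i, j, f, o = np.split(lstm_matrix, 4, axis=1)
  c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)  # new cell state
  m = sigmoid(o) * np.tanh(c)                                      # new output
  return m, c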
def __call__(self, inputs, state, scope=None):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D, batch x num_units.
    state: if `state_is_tuple` is False, this must be a state Tensor,
      `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
      tuple of state Tensors, both `2-D`, with column sizes `c_state` and
      `m_state`.
    scope: VariableScope for the created subgraph; defaults to "LSTMCell".

  Returns:
    A tuple containing:
    - A `2-D, [batch x output_dim]`, Tensor representing the output of the
      LSTM after reading `inputs` when previous state was `state`.
      Here output_dim is: num_proj if num_proj was set, num_units otherwise.
    - Tensor(s) representing the new state of LSTM after reading `inputs`
      when the previous state was `state`.  Same type and shape(s) as `state`.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  num_proj = self._num_units if self._num_proj is None else self._num_proj

  (c_prev, m_prev) = state

  dtype = inputs.dtype
  input_size = inputs.get_shape().with_rank(2)[1]
  if input_size.value is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  scope_name = scope or type(self).__name__
  with vs.variable_scope(scope_name,
                         initializer=self._initializer):  # "LSTMCell"
    if self._bn:
      concat_w_i = _get_concat_variable(
          "W_i", [input_size.value, 4 * self._num_units],
          dtype, self._num_unit_shards)
      concat_w_r = _get_concat_variable(
          "W_r", [num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
      b = vs.get_variable(
          "B", shape=[4 * self._num_units],
          initializer=array_ops.zeros_initializer, dtype=dtype)
    else:
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
      b = vs.get_variable(
          "B", shape=[4 * self._num_units],
          initializer=array_ops.zeros_initializer, dtype=dtype)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    if self._bn:
      lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                 self._deterministic, shift=False,
                                 scope=scope_name + 'bn_i')
      if self._bn > 1:
        lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r),
                                   self._deterministic, shift=False,
                                   scope=scope_name + 'bn_r')
      else:
        lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
      lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)
    else:
      cell_inputs = array_ops.concat(1, [inputs, m_prev])
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)

    i, j, f, o = array_ops.split(1, 4, lstm_matrix)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
           sigmoid(i + w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev +
           sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    if self._use_peepholes:
      if self._bn > 2:
        m = sigmoid(o + w_o_diag * c) * self._activation(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      if self._bn > 2:
        m = sigmoid(o) * self._activation(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj],
          dtype, self._num_proj_shards)
      m = math_ops.matmul(m, concat_w_proj)

  new_state = LSTMStateTuple(c, m)
  if not self._return_gate:
    return m, new_state
  else:
    return m, new_state, (i, j, f, o)
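# NumPy sketch of what the self._bn branch above computes: the input and
# recurrent projections are normalized separately (shift disabled) and only
# then summed with the shared bias, in the spirit of recurrent batch
# normalization.  batch_norm_like is a simplified stand-in for the batch_norm
# helper used in the cell, not its actual signature.
import numpy as np

def batch_norm_like(z, eps=1e-5):
  # normalize each feature over the batch axis; scale/shift omitted for brevity
  return (z - z.mean(axis=0, keepdims=True)) / np.sqrt(z.var(axis=0, keepdims=True) + eps)

def bn_lstm_preactivation(x, m_prev, W_i, W_r, b):
  # mirrors: batch_norm(x . W_i) + batch_norm(m_prev . W_r) + b
  return batch_norm_like(x.dot(W_i)) + batch_norm_like(m_prev.dot(W_r)) + b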
def __call__(self, inputs, state, scope=None):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D, batch x num_units.
    state: if `state_is_tuple` is False, this must be a state Tensor,
      `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
      tuple of state Tensors, both `2-D`, with column sizes `c_state` and
      `m_state`.
    scope: VariableScope for the created subgraph; defaults to "LSTMCell".

  Returns:
    A tuple containing:
    - A `2-D, [batch x output_dim]`, Tensor representing the output of the
      LSTM after reading `inputs` when previous state was `state`.
      Here output_dim is: num_proj if num_proj was set, num_units otherwise.
    - Tensor(s) representing the new state of LSTM after reading `inputs`
      when the previous state was `state`.  Same type and shape(s) as `state`.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  num_proj = self._num_units if self._num_proj is None else self._num_proj

  (c_prev, m_prev) = state

  dtype = inputs.dtype
  input_size = inputs.get_shape().with_rank(2)[1]
  if input_size.value is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  scope_name = scope or type(self).__name__
  with vs.variable_scope(scope_name,
                         initializer=self._initializer):  # "LSTMCell"
    if self._bn:
      concat_w_i = _get_concat_variable(
          "W_i", [input_size.value, 4 * self._num_units],
          dtype, self._num_unit_shards)
      concat_w_r = _get_concat_variable(
          "W_r", [num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
      b = vs.get_variable(
          "B", shape=[4 * self._num_units],
          initializer=array_ops.zeros_initializer, dtype=dtype)
    else:
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
      b = vs.get_variable(
          "B", shape=[4 * self._num_units],
          initializer=array_ops.zeros_initializer, dtype=dtype)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    if self._bn:
      lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                 self._deterministic, shift=False,
                                 scope=scope_name + 'bn_i')
      if self._bn > 1:
        lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r),
                                   self._deterministic, shift=False,
                                   scope=scope_name + 'bn_r')
      else:
        lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
      lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)
    else:
      cell_inputs = array_ops.concat(1, [inputs, m_prev])
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)

    i, j, f, o = array_ops.split(1, 4, lstm_matrix)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
           sigmoid(i + w_i_diag * c_prev) * self._activation(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev +
           sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    if self._use_peepholes:
      if self._bn > 2:
        m = sigmoid(o + w_o_diag * c) * self._activation(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      if self._bn > 2:
        m = sigmoid(o) * self._activation(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj],
          dtype, self._num_proj_shards)
      m = math_ops.matmul(m, concat_w_proj)

  new_state = LSTMStateTuple(c, m)
  if not self._return_gate:
    return m, new_state
  else:
    return m, new_state, (i, j, f, o)
def __call__(self, inputs, state, scope=None):
  """Run one step of MemoryLSTM.

  Args:
    inputs: input Tensor, 2D, batch x num_units.
    state: if `state_is_tuple` is False, this must be a state Tensor,
      `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
      tuple of state Tensors, both `2-D`, with column sizes `c_state` and
      `m_state`.
    scope: VariableScope for the created subgraph; defaults to "LSTMCell".

  Returns:
    A tuple containing:
    - A `2-D, [batch x output_dim]`, Tensor representing the output of the
      LSTM after reading `inputs` when previous state was `state`.
      Here output_dim is: num_proj if num_proj was set, num_units otherwise.
    - Tensor(s) representing the new state of LSTM after reading `inputs`
      when the previous state was `state`.  Same type and shape(s) as `state`.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  (a, h_prev_summary, c_tape_prev, h_tape_prev) = state

  dtype = inputs.dtype
  input_size = inputs.get_shape().with_rank(2)[1]
  if input_size.value is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  with vs.variable_scope(scope or type(self).__name__,
                         initializer=self._initializer):  # "LSTMCell"
    concat_w = rnn_cell._get_concat_variable(
        "W", [input_size.value + self._num_units, 4 * self._num_units],
        dtype, 1)
    b = vs.get_variable("Bias", shape=[4 * self._num_units],
                        initializer=array_ops.zeros_initializer, dtype=dtype)

    # reshape tape to 3D
    c_tape_prev = array_ops.reshape(c_tape_prev, [-1, self._attn_length, self._num_units])
    h_tape_prev = array_ops.reshape(h_tape_prev, [-1, self._attn_length, self._num_units])

    a, new_c_summary, new_h_summary = self._attention(inputs, h_prev_summary,
                                                      c_tape_prev, h_tape_prev)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat(1, [inputs, new_h_summary])
    lstm_matrix = tf.nn.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
    i, j, f, o = array_ops.split(1, 4, lstm_matrix)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + w_f_diag * new_c_summary) * new_c_summary +
           sigmoid(i + w_i_diag * new_c_summary) * self._activation(j))
    else:
      c = (sigmoid(f + self._forget_bias) * new_c_summary +
           sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)

    if self._use_peepholes:
      h = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      h = sigmoid(o) * self._activation(c)

    # remove the oldest slot from the tape
    new_h_tape = array_ops.slice(h_tape_prev, [0, 1, 0], [-1, -1, -1])
    new_c_tape = array_ops.slice(c_tape_prev, [0, 1, 0], [-1, -1, -1])
    # append the new c and h to the tape
    new_c_tape = array_ops.concat(1, [new_c_tape, array_ops.expand_dims(c, 1)])
    new_h_tape = array_ops.concat(1, [new_h_tape, array_ops.expand_dims(h, 1)])
    # flatten the tape to 2D
    new_c_tape = array_ops.reshape(new_c_tape, [-1, self._attn_length * self._num_units])
    new_h_tape = array_ops.reshape(new_h_tape, [-1, self._attn_length * self._num_units])

    new_state = (a, new_h_summary, new_c_tape, new_h_tape)

  return h, new_state
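# NumPy sketch of the tape update performed above: drop the oldest slot and
# append the newest (c, h) pair, so the tape always holds the most recent
# attn_length states.  The helper name and shapes are illustrative assumptions.
import numpy as np

def update_tape(tape, new_entry):
  # tape: [batch, memory, cell], new_entry: [batch, cell]
  return np.concatenate([tape[:, 1:, :], new_entry[:, None, :]], axis=1)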
def __call__(self, inputs, state, scope=None):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D, batch x num_units.
    state: state Tensor, 2D, batch x state_size.
    scope: VariableScope for the created subgraph; defaults to "LSTMCell".

  Returns:
    A tuple containing:
    - A 2D, batch x output_dim, Tensor representing the output of the LSTM
      after reading "inputs" when previous state was "state".  Here
      output_dim is: num_proj if num_proj was set, num_units otherwise.
    - A 2D, batch x state_size, Tensor representing the new state of LSTM
      after reading "inputs" when previous state was "state".

  Raises:
    ValueError: if an input_size was specified and the provided inputs have
      a different dimension.
  """
  num_proj = self._num_units if self._num_proj is None else self._num_proj

  c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
  m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

  dtype = inputs.dtype
  actual_input_size = inputs.get_shape().as_list()[1]
  if self._input_size and self._input_size != actual_input_size:
    raise ValueError("Actual input size not same as specified: %d vs %d." %
                     (actual_input_size, self._input_size))

  scope_name = scope or type(self).__name__
  with vs.variable_scope(scope_name,
                         initializer=self._initializer):  # "LSTMCell"
    if not self._bn:
      concat_w = _get_concat_variable(
          "W", [actual_input_size + num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)
    else:
      concat_w_i = _get_concat_variable(
          "W_i", [actual_input_size, 4 * self._num_units],
          dtype, self._num_unit_shards)
      concat_w_r = _get_concat_variable(
          "W_r", [num_proj, 4 * self._num_units],
          dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B", shape=[4 * self._num_units],
        initializer=array_ops.zeros_initializer, dtype=dtype)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    if not self._bn:
      cell_inputs = array_ops.concat(1, [inputs, m_prev])
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
    else:
      lstm_matrix_i = batch_norm(math_ops.matmul(inputs, concat_w_i),
                                 self._deterministic, shift=False,
                                 scope=scope_name + 'bn_i')
      if self._bn > 1:
        lstm_matrix_r = batch_norm(math_ops.matmul(m_prev, concat_w_r),
                                   self._deterministic, shift=False,
                                   scope=scope_name + 'bn_r')
      else:
        lstm_matrix_r = math_ops.matmul(m_prev, concat_w_r)
      lstm_matrix = nn_ops.bias_add(math_ops.add(lstm_matrix_i, lstm_matrix_r), b)

    i, j, f, o = array_ops.split(1, 4, lstm_matrix)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
           sigmoid(i + w_i_diag * c_prev) * tanh(j))
    else:
      c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

    if self._cell_clip is not None:
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)

    if self._use_peepholes:
      if self._bn > 2:
        m = sigmoid(o + w_o_diag * c) * tanh(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
    else:
      if self._bn > 2:
        m = sigmoid(o) * tanh(
            batch_norm(c, self._deterministic, scope=scope_name + 'bn_m'))
      else:
        m = sigmoid(o) * tanh(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj],
          dtype, self._num_proj_shards)
      m = math_ops.matmul(m, concat_w_proj)

  if not self._return_gate:
    return m, array_ops.concat(1, [c, m])
  else:
    return m, array_ops.concat(1, [c, m]), (i, j, f, o)