def call(self, inputs, state):
    """Run one step of the RUM (Rotational Unit of Memory) cell.

    Args:
        inputs: input tensor for this step.
        state: previous hidden state.
        # NOTE(review): both presumably `[batch, dim]` 2-D tensors, as
        # implied by the `_linear`/axis-1 split usage — confirm upstream.

    Returns:
        A `(output, new_state)` pair; for this cell both elements are the
        new hidden state `new_h`.
    """
    with vs.variable_scope("gates"):
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            # Default the gate bias to 1.0 (update gate starts open),
            # using the dtype of the incoming tensors.
            dtype = [a.dtype for a in [inputs, state]][0]
            bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
        value = _linear([inputs, state], 2 * self._hidden_size, True,
                        bias_ones, aux.rum_ortho_initializer())
        # r feeds the rotation; u is the update gate.
        r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
        u = sigmoid(u)
        if self._use_layer_norm:
            # Layer-normalize r and u jointly, then split them back apart.
            concat = tf.concat([r, u], 1)
            concat = aux.layer_norm_all(concat, 2, self._hidden_size, "LN_r_u")
            r, u = tf.split(concat, 2, 1)
    with vs.variable_scope("candidate"):
        x_emb = _linear(inputs, self._hidden_size, True,
                        self._bias_initializer, self._kernel_initializer)
        # Rotate the previous state toward the embedded input.
        state_new = rotate(x_emb, r, state)
        if self._use_layer_norm:
            c = self._activation(aux.layer_norm(x_emb + state_new, "LN_c"))
        else:
            c = self._activation(x_emb + state_new)
    # Convex combination of old state and candidate via the update gate.
    new_h = u * state + (1 - u) * c
    if self._T_norm is not None:  # fix: identity check (`is not None`), not `!= None`
        # Rescale the state to a fixed L2 norm (time normalization).
        new_h = tf.nn.l2_normalize(new_h, 1, epsilon=self._eps) * self._T_norm
    if self._use_zoneout:
        new_h = aux.rum_zoneout(new_h, state, self._zoneout_keep_h,
                                self._is_training)
    return new_h, new_h
def call(self, x, state):
    """One step of a layer-normalized LSTM with optional zoneout.

    Args:
        x: input tensor for this step.
        state: `(h, c)` tuple of previous hidden and cell states.

    Returns:
        `(new_h, (new_h, new_c))`.
    """
    with tf.variable_scope(type(self).__name__):
        prev_h, prev_c = state
        hidden = self.num_units
        in_dim = x.get_shape().as_list()[1]

        # Orthogonal kernels, zero bias.
        W_xh = tf.get_variable('W_xh', [in_dim, 4 * hidden],
                               initializer=aux.orthogonal_initializer(1.0),
                               dtype=tf.float32)
        W_hh = tf.get_variable('W_hh', [hidden, 4 * hidden],
                               initializer=aux.orthogonal_initializer(1.0),
                               dtype=tf.float32)
        bias = tf.get_variable('bias', [4 * hidden],
                               initializer=tf.constant_initializer(0.0),
                               dtype=tf.float32)

        # Fuse the two matmuls into one over the concatenated operands
        # (speed optimization, same math).
        stacked_inputs = tf.concat(axis=1, values=[x, prev_h])
        stacked_kernel = tf.concat(axis=0, values=[W_xh, W_hh])
        pre_acts = tf.matmul(stacked_inputs, stacked_kernel) + bias
        pre_acts = aux.layer_norm_all(pre_acts, 4, hidden, 'ln')

        # i = input gate, j = candidate input, f = forget gate, o = output gate.
        i, j, f, o = tf.split(axis=1, num_or_size_splits=4, value=pre_acts)

        new_c = (prev_c * tf.sigmoid(f + self.f_bias)
                 + tf.sigmoid(i) * tf.tanh(j))
        new_h = tf.tanh(aux.layer_norm(new_c, 'ln_c')) * tf.sigmoid(o)

        if self.use_zoneout:
            new_h, new_c = aux.zoneout(new_h, new_c, prev_h, prev_c,
                                       self.zoneout_keep_h,
                                       self.zoneout_keep_c,
                                       self.is_training)
    return new_h, (new_h, new_c)
def call(self, inputs, state):
    """Run one step of the associative-memory RUM cell.

    The flat `state` packs, per batch element, a
    `hidden_size * hidden_size` associative rotation memory followed by
    the `hidden_size` hidden state.

    Args:
        inputs: input tensor for this step.
        state: packed state, shape `[batch, hidden^2 + hidden]`.

    Returns:
        `(new_h, new_state)` where `new_state` re-packs the updated
        rotation memory and the new hidden state.
    """
    # Unpack the associative memory matrix and the hidden state.
    size_batch = tf.shape(state)[0]
    assoc_mem, state = tf.split(
        state, [self._hidden_size * self._hidden_size, self._hidden_size], 1)
    assoc_mem = tf.reshape(
        assoc_mem, [size_batch, self._hidden_size, self._hidden_size])
    with vs.variable_scope("gates"):
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            # Default the gate bias to 1.0 in the inputs' dtype.
            dtype = [a.dtype for a in [inputs, state]][0]
            bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
        value = fully_connected(
            inputs=tf.concat([inputs, state], axis=1),
            num_outputs=2 * self._hidden_size,
            activation_fn=None,
            biases_initializer=bias_ones,
            weights_initializer=aux.rum_ortho_initializer())
        # r feeds the rotation; u is the update gate.
        r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
        u = sigmoid(u)
        if self._use_layer_norm:
            # Layer-normalize r and u jointly, then split them back apart.
            concat = tf.concat([r, u], 1)
            concat = aux.layer_norm_all(concat, 2, self._hidden_size, "LN_r_u")
            r, u = tf.split(concat, 2, 1)
    with vs.variable_scope("candidate"):
        x_emb = fully_connected(
            inputs=inputs,
            num_outputs=self._hidden_size,
            activation_fn=None,
            biases_initializer=self._bias_initializer,
            weights_initializer=self._kernel_initializer)
        # Compose this step's rotation into the accumulated memory, then
        # apply the accumulated rotation to the previous hidden state.
        tmp_rotation = rotation_operator(x_emb, r, self._hidden_size)
        Rt = tf.matmul(assoc_mem, tmp_rotation)
        state_new = tf.reshape(
            tf.matmul(
                Rt, tf.reshape(state, [size_batch, self._hidden_size, 1])),
            [size_batch, self._hidden_size])
        if self._use_layer_norm:
            c = self._activation(aux.layer_norm(x_emb + state_new, "LN_c"))
        else:
            c = self._activation(x_emb + state_new)
    # Convex combination of old state and candidate via the update gate.
    new_h = u * state + (1 - u) * c
    if self._T_norm is not None:  # fix: identity check (`is not None`), not `!= None`
        # Rescale the state to a fixed L2 norm (time normalization).
        new_h = tf.nn.l2_normalize(new_h, 1, epsilon=self._eps) * self._T_norm
    if self._use_zoneout:
        new_h = aux.rum_zoneout(new_h, state, self._zoneout_keep_h,
                                self._is_training)
    # Re-flatten the memory and pack it with the hidden state.
    Rt = tf.reshape(Rt, [size_batch, self._hidden_size * self._hidden_size])
    new_state = tf.concat([Rt, new_h], 1)
    return new_h, new_state