def Orthonorm(x, name=None):
    '''
    Builds a Keras layer that handles orthogonalization of x

    x:      an n x d input matrix
    name:   name of the Keras layer

    returns: a Keras layer instance. During evaluation, the instance returns an
             n x d orthogonal matrix if x is full rank and not singular
    '''
    # get dimensionality of x
    d = x.get_shape().as_list()[-1]
    # compute orthogonalizing matrix
    ortho_weights = orthonorm_op(x)
    # create variable that holds this matrix
    ortho_weights_store = K.variable(np.zeros((d, d)))
    # create op that saves matrix into variable
    ortho_weights_update = tf.assign(ortho_weights_store, ortho_weights,
                                     name='ortho_weights_update')
    # switch between stored and freshly computed weights for inference vs. training
    l = Lambda(lambda x: K.in_train_phase(K.dot(x, ortho_weights),
                                          K.dot(x, ortho_weights_store)),
               name=name)
    l.add_update(ortho_weights_update)
    return l
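# A minimal usage sketch, assuming the TF 1.x-style imports used above
# (K, np, tf, Lambda) and the orthonorm_op helper defined further down;
# the layer sizes and names here are illustrative, not from the original.
from keras.layers import Input, Dense
from keras.models import Model

inp = Input(shape=(10,))
h = Dense(5)(inp)
# Build the orthogonalization layer from the symbolic tensor, then apply it
# to that same tensor; at train time the weights are recomputed from the batch.
y = Orthonorm(h, name='orthonorm')(h)
model = Model(inputs=inp, outputs=y)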
def call(self, inputs, mask=None, **kwargs):
    """Core implementation of soft attention.

    Args:
        inputs (object): input tensor.

    Returns:
        object: weighted sum of input tensors.
    """
    attention = K.tanh(K.dot(inputs, self.W) + self.b)
    attention = K.dot(attention, self.q)
    attention = K.squeeze(attention, axis=2)

    if mask is None:
        attention = K.exp(attention)
    else:
        attention = K.exp(attention) * K.cast(mask, dtype="float32")

    attention_weight = attention / (
        K.sum(attention, axis=-1, keepdims=True) + K.epsilon())

    attention_weight = K.expand_dims(attention_weight)
    weighted_input = inputs * attention_weight
    return K.sum(weighted_input, axis=1)
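# The same additive (soft) attention computation as a plain NumPy sketch,
# useful for checking shapes. W, b and q mirror the layer's learned weights;
# all sizes are illustrative, not from the original.
import numpy as np

batch, seq_len, dim, att_dim = 2, 4, 8, 16
inputs = np.random.rand(batch, seq_len, dim)
W, b, q = np.random.rand(dim, att_dim), np.zeros(att_dim), np.random.rand(att_dim, 1)

scores = np.tanh(inputs @ W + b) @ q                  # (batch, seq_len, 1)
scores = np.squeeze(scores, axis=2)                   # (batch, seq_len)
weights = np.exp(scores)
weights = weights / (weights.sum(axis=-1, keepdims=True) + 1e-7)
output = (inputs * weights[..., None]).sum(axis=1)    # (batch, dim)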
def call(self, QKVs):
    """Core logic of multi-head self attention.

    Args:
        QKVs (list): inputs of multi-head self attention, i.e. query, key and value.

    Returns:
        object: output tensors.
    """
    if len(QKVs) == 3:
        Q_seq, K_seq, V_seq = QKVs
        Q_len, V_len = None, None
    elif len(QKVs) == 5:
        Q_seq, K_seq, V_seq, Q_len, V_len = QKVs

    # project and split into heads: (batch, heads, seq_len, head_dim)
    Q_seq = K.dot(Q_seq, self.WQ)
    Q_seq = K.reshape(Q_seq, shape=(-1, K.shape(Q_seq)[1], self.multiheads, self.head_dim))
    Q_seq = K.permute_dimensions(Q_seq, pattern=(0, 2, 1, 3))

    K_seq = K.dot(K_seq, self.WK)
    K_seq = K.reshape(K_seq, shape=(-1, K.shape(K_seq)[1], self.multiheads, self.head_dim))
    K_seq = K.permute_dimensions(K_seq, pattern=(0, 2, 1, 3))

    V_seq = K.dot(V_seq, self.WV)
    V_seq = K.reshape(V_seq, shape=(-1, K.shape(V_seq)[1], self.multiheads, self.head_dim))
    V_seq = K.permute_dimensions(V_seq, pattern=(0, 2, 1, 3))

    # scaled dot-product attention scores
    A = einsum("abij, abkj -> abik", Q_seq, K_seq) / K.sqrt(
        K.cast(self.head_dim, dtype="float32"))
    A = K.permute_dimensions(
        A, pattern=(0, 3, 2, 1)
    )  # A.shape = [batch_size, K_sequence_length, Q_sequence_length, self.multiheads]

    A = self.Mask(A, V_len, "add")
    A = K.permute_dimensions(A, pattern=(0, 3, 2, 1))

    if self.mask_right:
        # mask out attention to future positions (causal masking)
        ones = K.ones_like(A[:1, :1])
        lower_triangular = K.tf.matrix_band_part(ones, num_lower=-1, num_upper=0)
        mask = (ones - lower_triangular) * 1e12
        A = A - mask

    A = K.softmax(A)

    O_seq = einsum("abij, abjk -> abik", A, V_seq)
    O_seq = K.permute_dimensions(O_seq, pattern=(0, 2, 1, 3))
    O_seq = K.reshape(O_seq, shape=(-1, K.shape(O_seq)[1], self.output_dim))
    O_seq = self.Mask(O_seq, Q_len, "mul")
    return O_seq
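# A NumPy sketch of the unmasked multi-head attention path above (project,
# split into heads, scaled dot-product, merge heads). All sizes, and the WQ/WK/WV
# initializations, are illustrative assumptions, not the layer's trained weights.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, seq_len, model_dim, heads = 2, 5, 16, 4
head_dim = model_dim // heads
x = np.random.rand(batch, seq_len, model_dim)
WQ, WK, WV = (np.random.rand(model_dim, model_dim) for _ in range(3))

def split_heads(t):  # (batch, seq, model_dim) -> (batch, heads, seq, head_dim)
    return t.reshape(batch, seq_len, heads, head_dim).transpose(0, 2, 1, 3)

Q, K_, V = split_heads(x @ WQ), split_heads(x @ WK), split_heads(x @ WV)
A = softmax(np.einsum("bhid,bhjd->bhij", Q, K_) / np.sqrt(head_dim))
O = np.einsum("bhij,bhjd->bhid", A, V)                     # (batch, heads, seq, head_dim)
O = O.transpose(0, 2, 1, 3).reshape(batch, seq_len, model_dim)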
def call(self, x):
    # project the input into query, key and value spaces
    WQ = K.dot(x, self.kernel[0])
    WK = K.dot(x, self.kernel[1])
    WV = K.dot(x, self.kernel[2])

    # scaled dot-product attention: softmax(Q K^T / sqrt(d)) V, with d = 64 hard-coded
    QK = K.batch_dot(WQ, K.permute_dimensions(WK, [0, 2, 1]))
    QK = QK / (64 ** 0.5)
    QK = K.softmax(QK)

    V = K.batch_dot(QK, WV)
    return V
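# For reference, the batch_dot/permute_dimensions pattern above is a per-sample
# Q·K^T; a NumPy equivalent with illustrative shapes (not the layer's real tensors):
import numpy as np

WQ = np.random.rand(2, 5, 64)                          # (batch, seq_len, output_dim)
WK = np.random.rand(2, 5, 64)
QK = np.einsum("bik,bjk->bij", WQ, WK) / 64 ** 0.5     # (batch, seq_len, seq_len)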
def gru_with_z_gate(x, weight):
    h_tm1, inputs, r, hh = x[0], x[1], x[2], x[3]
    weight = K.variable(weight)
    units = h_tm1.shape[-1]

    kernel_z = weight[:units, :units]
    recurrent_kernel_z = weight[units:units * 2, :units]
    input_bias_z = weight[units * 2, :units]  # Change to 1 dim.

    x_z = K.bias_add(K.dot(inputs, kernel_z), input_bias_z)
    recurrent_z = K.dot(h_tm1, recurrent_kernel_z)
    z_without_activate = x_z + recurrent_z
    z = hard_sigmoid(z_without_activate)

    h = z * h_tm1 + (1 - z) * hh
    #return h
    return z
def gru_with_r_gate(x, weight):
    h_tm1, inputs, z, x_h, split_recurrent_h = x[0], x[1], x[2], x[3], x[4]
    weight = K.variable(weight)
    units = h_tm1.shape[-1]

    kernel_r = weight[:units, units:units * 2]
    recurrent_kernel_r = weight[units:units * 2, units:units * 2]
    input_bias_r = weight[units * 2, units:units * 2]  # Change to 1 dim.

    x_r = K.bias_add(K.dot(inputs, kernel_r), input_bias_r)
    recurrent_r = K.dot(h_tm1, recurrent_kernel_r)
    r_without_activate = x_r + recurrent_r
    r = hard_sigmoid(r_without_activate)
    #r = hard_sigmoid(x_r + recurrent_r)

    # Recompute recurrent_h by two parts.
    r_unsqueeze = K.expand_dims(r, axis=-1)
    recompute_recurrent_h = K.sum(r_unsqueeze * split_recurrent_h, axis=1)
    hh = tanh(x_h + recompute_recurrent_h)

    h = z * h_tm1 + (1 - z) * hh
    #return h
    return r
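# Both gate functions above slice `weight` as if the GRU kernel, recurrent kernel
# and bias were stacked row-wise. A hypothetical sketch of how such an array could
# be assembled from a trained Keras GRU layer, assuming input_dim == units and a
# single flat bias (reset_after=False), which is what the row slicing implies:
import numpy as np

kernel, recurrent_kernel, bias = gru_layer.get_weights()  # gru_layer: assumed trained keras.layers.GRU
# kernel:           (input_dim, 3 * units)   columns ordered [z | r | h]
# recurrent_kernel: (units,     3 * units)
# bias:             (3 * units,)
weight = np.vstack([kernel, recurrent_kernel, bias[np.newaxis, :]])  # (2*units + 1, 3*units)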
def orthonorm_op(x, epsilon=1e-7):
    '''
    Computes a matrix that orthogonalizes the input matrix x

    x:          an n x d input matrix
    epsilon:    small value added to the diagonal of x^T x to keep it positive
                definite, so the Cholesky factorization is well defined

    returns:    a d x d matrix, ortho_weights, which orthogonalizes x by
                right multiplication
    '''
    x_2 = K.dot(K.transpose(x), x)
    x_2 += K.eye(K.int_shape(x)[1]) * epsilon
    L = tf.cholesky(x_2)
    ortho_weights = tf.transpose(tf.matrix_inverse(L)) * tf.sqrt(
        tf.cast(tf.shape(x)[0], dtype=K.floatx()))
    return ortho_weights
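# A NumPy sketch of the same computation, handy for checking that right-multiplying
# by ortho_weights yields columns that are orthonormal up to a sqrt(n) scale.
# Sizes are illustrative.
import numpy as np

n, d = 100, 5
x = np.random.rand(n, d)
x_2 = x.T @ x + np.eye(d) * 1e-7
L = np.linalg.cholesky(x_2)                       # x_2 = L @ L.T
ortho_weights = np.linalg.inv(L).T * np.sqrt(n)
y = x @ ortho_weights
print(np.allclose(y.T @ y / n, np.eye(d), atol=1e-4))  # True: y's columns are orthogonal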
def get_GRU_components(inputs, states, weight):
    units = weight[1].shape[0]  # recurrent kernel has shape (units, units * 3)
    kernel = K.variable(weight[0])  # shape = (input_dim, self.units * 3)
    recurrent_kernel = K.variable(weight[1])  # shape = (self.units, self.units * 3)
    bias = K.variable(weight[2])  # bias_shape = (3 * self.units,)
    inputs = K.variable(inputs)  # Not sure.
    h_tm1 = K.variable(states)  # Previous memory state.

    # Update gate.
    kernel_z = kernel[:, :units]
    recurrent_kernel_z = recurrent_kernel[:, :units]
    input_bias_z = bias[:units]
    # Reset gate.
    kernel_r = kernel[:, units:units * 2]
    recurrent_kernel_r = recurrent_kernel[:, units:units * 2]
    input_bias_r = bias[units:units * 2]
    # New gate.
    kernel_h = kernel[:, units * 2:]
    recurrent_kernel_h = recurrent_kernel[:, units * 2:]
    input_bias_h = bias[units * 2:]

    x_z = K.bias_add(K.dot(inputs, kernel_z), input_bias_z)
    x_r = K.bias_add(K.dot(inputs, kernel_r), input_bias_r)
    x_h = K.bias_add(K.dot(inputs, kernel_h), input_bias_h)

    recurrent_z = K.dot(h_tm1, recurrent_kernel_z)
    recurrent_r = K.dot(h_tm1, recurrent_kernel_r)

    z = hard_sigmoid(x_z + recurrent_z)  # Recurrent activation = 'hard_sigmoid'.
    r = hard_sigmoid(x_r + recurrent_r)

    recurrent_h = K.dot(r * h_tm1, recurrent_kernel_h)
    # Get split part of recurrent_h.
    split_recurrent_h = K.expand_dims(h_tm1, axis=-1) * recurrent_kernel_h
    r_unsqueeze = K.expand_dims(r, axis=-1)
    recompute_recurrent_h = K.sum(r_unsqueeze * split_recurrent_h, axis=1)
    #print(recurrent_h.shape, h_tm1.shape, recurrent_kernel_h.shape, split_recurrent_h.shape)
    #print(K.get_value(recompute_recurrent_h)[0, :3], np.mean(K.get_value(recompute_recurrent_h)))
    #print(K.get_value(recurrent_h)[0, :3], np.mean(K.get_value(recurrent_h)))
    delta = np.mean(
        np.abs(K.get_value(recompute_recurrent_h) - K.get_value(recurrent_h)))
    print("delta =", delta, np.mean(K.get_value(recompute_recurrent_h)),
          np.mean(K.get_value(recurrent_h)))
    assert delta < 1e-6, "r gate is wrong."

    hh = tanh(x_h + recurrent_h)  # Activation = 'tanh'.

    # Previous and candidate state mixed by update gate.
    h = z * h_tm1 + (1 - z) * hh
    return K.get_value(h_tm1), K.get_value(h), K.get_value(z), K.get_value(
        r), K.get_value(hh), K.get_value(x_h), K.get_value(split_recurrent_h)
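# Hypothetical call sketch (all names here are illustrative): step_inputs holds one
# step of inputs with shape (batch, input_dim), prev_states the previous hidden state
# with shape (batch, units), and the weights come from a trained keras.layers.GRU.
h_tm1, h, z, r, hh, x_h, split_recurrent_h = get_GRU_components(
    step_inputs, prev_states, gru_layer.get_weights())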
def call(self, input):
    f = list()
    fn = list()
    c_dict = keypoint_connections()

    # per-keypoint feature maps from the shared convolution
    for i in range(21):
        f.append(self.conv2(input))

    for i in range(21):
        # concatenate each keypoint's features with those of its connected keypoints
        c = list()
        c.append(f[i])
        for j in c_dict[i]:
            c.append(f[j])
        c = K.concatenate(tuple(c), axis=-1)

        h = self.conv2a[i](c)
        g = self.conv2b[i](h)
        l = K.dot(g, self.W[i])
        print(l.shape)
        fn.append(
            K.sum(K.concatenate((f[i], l), axis=-1), axis=-1, keepdims=True))

    return K.concatenate(tuple(fn), axis=-1)
def call(self, x, mask=None):
    features_dim = self.features_dim
    step_dim = self.step_dim

    e = K.reshape(
        K.dot(K.reshape(x, (-1, features_dim)),
              K.reshape(self.W, (features_dim, 1))),
        (-1, step_dim))  # e = K.dot(x, self.W)
    if self.bias:
        e += self.b
    e = K.tanh(e)

    a = K.exp(e)
    # apply mask after the exp; will be re-normalized next
    if mask is not None:
        # cast the mask to floatX to avoid float64 upcasting in theano
        a *= K.cast(mask, K.floatx())
    # in some cases, especially in the early stages of training, the sum may be almost zero
    # and this results in NaN's. A workaround is to add a very small positive number ε to the sum.
    a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

    a = K.expand_dims(a)
    c = K.sum(a * x, axis=1)
    return c
def call(self, inputs, states):
    prev_output = states[0]
    h = K.dot(inputs, self.kernel)
    output = h + K.dot(prev_output, self.recurrent_kernel)
    return output, [output]
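# A hypothetical usage sketch, assuming this call() belongs to a minimal RNN cell
# class (here called MinimalRNNCell) that defines state_size, kernel and
# recurrent_kernel in build(), in the style of the Keras custom-cell example.
from keras.layers import RNN, Input
from keras.models import Model

cell = MinimalRNNCell(32)        # hypothetical cell class wrapping the call() above
x = Input(shape=(None, 5))       # (batch, timesteps, features)
y = RNN(cell)(x)                 # (batch, 32): last hidden state
model = Model(x, y)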