def call(self, u_vecs):
    """Route the input capsules ``u_vecs`` to ``self.num_capsule`` output capsules.

    The inputs are projected with ``self.W`` (a shared 1-D convolution when
    ``self.share_weights`` is truthy, a locally connected one otherwise) and
    ``self.routings`` rounds of routing are then run on the projections.

    Args:
        u_vecs: backend tensor; axis 0 is read as the batch size and axis 1
            as the number of input capsules.

    Returns:
        ``self.activation`` applied to the routed sum of the last round.
        NOTE(review): with ``self.routings == 0`` the loop body never runs
        and the return fails on an unbound local -- confirm the constructor
        rules that value out.
    """
    projected = (
        K.conv1d(u_vecs, self.W)
        if self.share_weights
        else K.local_conv1d(u_vecs, self.W, [1], [1])
    )

    n_batch = K.shape(u_vecs)[0]
    n_inputs = K.shape(u_vecs)[1]

    # Split the projection per output capsule, then bring the output-capsule
    # axis forward: (batch, num_capsule, input_num_capsule, dim_capsule).
    target_shape = (n_batch, n_inputs, self.num_capsule, self.dim_capsule)
    projected = K.permute_dimensions(
        K.reshape(projected, target_shape), (0, 2, 1, 3))

    # Routing logits, one per (output capsule, input capsule) pair.
    logits = K.zeros_like(projected[:, :, :, 0])

    for round_idx in range(self.routings):
        # K.softmax works on the last axis, so move num_capsule there,
        # normalise, and restore the original layout for both tensors.
        logits_t = K.permute_dimensions(logits, (0, 2, 1))
        coupling = K.permute_dimensions(K.softmax(logits_t), (0, 2, 1))
        logits = K.permute_dimensions(logits_t, (0, 2, 1))

        outputs = self.activation(K.batch_dot(coupling, projected, [2, 2]))

        # The logits are not needed after the final round, so skip the update.
        if round_idx + 1 < self.routings:
            logits += K.batch_dot(outputs, projected, [2, 3])

    return outputs
def routing(self, u_vecs):
    """Run ``self.routings`` routing rounds and return poses and activations.

    Args:
        u_vecs: backend tensor; axis 0 is read as the batch size and axis 1
            as the number of input capsules.

    Returns:
        A ``(poses, activations)`` pair. ``poses`` is ``self.squash`` applied
        to the routed sum of the last round; ``activations`` is the square
        root of the sum of squares of ``poses`` along axis 2.
    """
    # Project the inputs with either a shared or a locally connected kernel.
    projected = (
        K.conv1d(u_vecs, self.W)
        if self.share_weights
        else K.local_conv1d(u_vecs, self.W, [1], [1])
    )

    n_batch = K.shape(u_vecs)[0]
    n_inputs = K.shape(u_vecs)[1]

    # Resulting layout: [None, num_capsule, input_num_capsule, dim_capsule].
    target_shape = (n_batch, n_inputs, self.num_capsule, self.dim_capsule)
    projected = K.permute_dimensions(
        K.reshape(projected, target_shape), (0, 2, 1, 3))

    # Routing logits, layout [None, num_capsule, input_num_capsule].
    logits = K.zeros_like(projected[:, :, :, 0])

    for round_idx in range(self.routings):
        # NOTE(review): K.softmax is applied to the logits as-is, i.e. over
        # their last axis (input_num_capsule per the layout above). Confirm
        # that normalising over input capsules rather than over num_capsule
        # is intended.
        coupling = K.softmax(logits)
        weighted_sum = K.batch_dot(coupling, projected, [2, 2])
        squashed = self.squash(weighted_sum)

        # No logit update after the final round; its result is never read.
        if round_idx + 1 < self.routings:
            logits = logits + K.batch_dot(squashed, projected, [2, 3])

    poses = squashed
    activations = K.sqrt(K.sum(K.square(poses), 2))
    return poses, activations
def call(self, x):
    """Interleave the channels of ``x`` across ``self.groups`` groups.

    The last axis is split into ``(groups, channels_per_group)``, those two
    axes are swapped, and the result is flattened back to the original
    channel count.

    Args:
        x: tensor whose static shape has exactly three entries after the
            leading one, unpacked here as height, width and channel count.

    Returns:
        A tensor reshaped to ``[-1, height, width, in_channels]`` with the
        channels reordered.

    Raises:
        ValueError: if ``self.groups`` is ``None`` and the channel count is
            not a multiple of ``self.groups_factor``.
    """
    trailing_dims = x.shape.as_list()[1:]
    height, width, in_channels = trailing_dims

    # Derive the group count lazily; it is stored on the layer for later calls.
    if self.groups is None:
        leftover = in_channels % self.groups_factor
        if leftover:
            raise ValueError("%s %% %s" % (in_channels, self.groups_factor))
        self.groups = in_channels // self.groups_factor

    group_size = in_channels // self.groups

    grouped = K.reshape(x, [-1, height, width, self.groups, group_size])
    # Swapping the two innermost axes is what reorders the channels.
    swapped = K.permute_dimensions(grouped, (0, 1, 2, 4, 3))
    return K.reshape(swapped, [-1, height, width, in_channels])
def step(self, inputs, states):
    """Compute one step of the attention-augmented LSTM recurrence.

    Args:
        inputs: the input for the current step. In implementation mode 0 it
            is sliced directly into the four gate pre-activations; in modes
            1 and 2 it is multiplied by the input kernels.
        states: indexable; entries 0-4 are read as the previous hidden
            state, the previous cell state, the input dropout mask(s), the
            recurrent dropout mask(s) and the sequence attended over.

    Returns:
        ``(output, [h, c])`` where ``output`` is the attention-weighted
        sequence when ``self.return_attention`` is truthy and the new hidden
        state ``h`` otherwise.

    Raises:
        ValueError: if ``self.implementation`` is not 0, 1 or 2.
    """
    prev_h, prev_c = states[0], states[1]
    dp_mask, rec_dp_mask = states[2], states[3]
    x_input = states[4]

    # Alignment model: score every position of x_input against prev_h.
    tiled_h = K.repeat(prev_h, self.timestep_dim)
    projected_x = _time_distributed_dense(
        x_input,
        self.attention_weights,
        self.attention_bias,
        output_dim=K.int_shape(self.attention_weights)[1],
    )
    scores = self.attention_activation(
        K.dot(tiled_h, self.attention_recurrent_weights) + projected_x)
    scores = K.squeeze(K.dot(scores, self.attention_recurrent_bias), 2)

    # Normalise exp(scores) along axis 1 by hand.
    alpha = K.exp(scores)
    # NOTE(review): dp_mask is None-checked here but indexed unconditionally
    # in implementation modes 1 and 2 below -- confirm it can never be None.
    if dp_mask is not None:
        alpha *= dp_mask[0]
    alpha /= K.sum(alpha, axis=1, keepdims=True)
    alpha_r = K.permute_dimensions(
        K.repeat(alpha, self.input_dim), (0, 2, 1))

    # Context vector (soft attention after Bahdanau et al.): weight the
    # sequence by alpha, keep the weighted sequence, and sum it over axis 1.
    context_sequence = x_input * alpha_r
    z_hat = K.sum(context_sequence, axis=1)

    mode = self.implementation
    if mode == 2:
        # Fused path: one matmul per source, then slice out the four gates.
        z = (K.dot(inputs * dp_mask[0], self.kernel)
             + K.dot(prev_h * rec_dp_mask[0], self.recurrent_kernel)
             + K.dot(z_hat, self.attention_kernel))
        if self.use_bias:
            z = K.bias_add(z, self.bias)

        units = self.units
        z_i = z[:, :units]
        z_f = z[:, units:2 * units]
        z_c = z[:, 2 * units:3 * units]
        z_o = z[:, 3 * units:]

        i = self.recurrent_activation(z_i)
        f = self.recurrent_activation(z_f)
        c = f * prev_c + i * self.activation(z_c)
        o = self.recurrent_activation(z_o)
    else:
        if mode == 0:
            # The inputs already hold the four gate contributions.
            units = self.units
            x_i = inputs[:, :units]
            x_f = inputs[:, units:2 * units]
            x_c = inputs[:, 2 * units:3 * units]
            x_o = inputs[:, 3 * units:]
        elif mode == 1:
            # Per-gate kernels, each with its own dropout mask.
            x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
            x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
            x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
            x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
        else:
            raise ValueError('Unknown `implementation` mode.')

        i = self.recurrent_activation(
            x_i
            + K.dot(prev_h * rec_dp_mask[0], self.recurrent_kernel_i)
            + K.dot(z_hat, self.attention_i))
        f = self.recurrent_activation(
            x_f
            + K.dot(prev_h * rec_dp_mask[1], self.recurrent_kernel_f)
            + K.dot(z_hat, self.attention_f))
        c = f * prev_c + i * self.activation(
            x_c
            + K.dot(prev_h * rec_dp_mask[2], self.recurrent_kernel_c)
            + K.dot(z_hat, self.attention_c))
        o = self.recurrent_activation(
            x_o
            + K.dot(prev_h * rec_dp_mask[3], self.recurrent_kernel_o)
            + K.dot(z_hat, self.attention_o))

    h = o * self.activation(c)

    # Flag the output as learning-phase dependent when any dropout is active.
    if self.dropout + self.recurrent_dropout > 0:
        h._uses_learning_phase = True

    step_output = context_sequence if self.return_attention else h
    return step_output, [h, c]