def call(self, u_vecs):
        """Route the input capsules to ``self.num_capsule`` output capsules.

        ``u_vecs`` is projected with ``self.W`` (``K.conv1d`` when
        ``self.share_weights`` is truthy, ``K.local_conv1d`` otherwise),
        split into one prediction vector per (output capsule, input capsule)
        pair, and combined over ``self.routings`` routing iterations.

        Returns the value of ``self.activation`` computed in the last
        iteration.

        NOTE(review): ``outputs`` is only bound inside the loop, so this
        presumably requires ``self.routings >= 1`` -- confirm at construction.
        """
        if self.share_weights:
            predictions = K.conv1d(u_vecs, self.W)
        else:
            predictions = K.local_conv1d(u_vecs, self.W, [1], [1])

        n_batch = K.shape(u_vecs)[0]
        n_input_capsules = K.shape(u_vecs)[1]

        # Split the projected features into
        # (batch, input_num_capsule, num_capsule, dim_capsule) ...
        predictions = K.reshape(
            predictions,
            (n_batch, n_input_capsules, self.num_capsule, self.dim_capsule))
        # ... then reorder to
        # (batch, num_capsule, input_num_capsule, dim_capsule).
        predictions = K.permute_dimensions(predictions, (0, 2, 1, 3))

        # Routing logits, one per (output capsule, input capsule) pair:
        # (batch, num_capsule, input_num_capsule), initialised to zero.
        logits = K.zeros_like(predictions[:, :, :, 0])

        swap_last_two = (0, 2, 1)
        final_iteration = self.routings - 1
        for iteration in range(self.routings):
            # K.softmax is called without an axis argument, so the logits are
            # transposed to put num_capsule last, normalised, and the result
            # is transposed back.
            coupling = K.permute_dimensions(
                K.softmax(K.permute_dimensions(logits, swap_last_two)),
                swap_last_two)

            # Contract axis 2 of both operands (the input-capsule axis).
            outputs = self.activation(
                K.batch_dot(coupling, predictions, [2, 2]))

            # The logits are not updated after the final iteration because
            # nothing would read them.
            if iteration != final_iteration:
                logits = logits + K.batch_dot(outputs, predictions, [2, 3])

        return outputs
示例#2
0
    def routing(self, u_vecs):
        """Run the routing loop and return capsule poses and activations.

        ``u_vecs`` is projected with ``self.W`` (``K.conv1d`` when
        ``self.share_weights`` is truthy, ``K.local_conv1d`` otherwise) and
        rearranged to
        ``[None, num_capsule, input_num_capsule, dim_capsule]``.

        Returns:
            A ``(poses, activations)`` pair: ``poses`` is the squashed output
            of the last routing iteration and ``activations`` is the square
            root of the sum of squares of ``poses`` along axis 2.

        NOTE(review): ``K.softmax`` is called without an axis argument, so
        the logits are normalised along their last axis, which is
        ``input_num_capsule`` here -- confirm that this, rather than the
        ``num_capsule`` axis, is intended.
        NOTE(review): ``poses`` is only bound inside the loop, so this
        presumably requires ``self.routings >= 1``.
        """
        projected = (K.conv1d(u_vecs, self.W)
                     if self.share_weights else
                     K.local_conv1d(u_vecs, self.W, [1], [1]))

        target_shape = (K.shape(u_vecs)[0], K.shape(u_vecs)[1],
                        self.num_capsule, self.dim_capsule)
        predictions = K.permute_dimensions(
            K.reshape(projected, target_shape), (0, 2, 1, 3))
        # predictions.shape = [None, num_capsule, input_num_capsule,
        #                      dim_capsule]

        # Routing logits, shape = [None, num_capsule, input_num_capsule].
        logits = K.zeros_like(predictions[:, :, :, 0])

        final_iteration = self.routings - 1
        for iteration in range(self.routings):
            coupling = K.softmax(logits)
            weighted_sum = K.batch_dot(coupling, predictions, [2, 2])
            poses = self.squash(weighted_sum)

            # Skip the logit update on the last pass; nothing reads it.
            if iteration != final_iteration:
                logits = logits + K.batch_dot(poses, predictions, [2, 3])

        squared_norm = K.sum(K.square(poses), 2)
        activations = K.sqrt(squared_norm)
        return poses, activations
示例#3
0
    def call(self, x):
        """Shuffle the channels of ``x`` across ``self.groups`` groups.

        ``x`` must have a statically known shape whose trailing three
        dimensions are ``(height, width, in_channels)``. The channel axis is
        viewed as ``(groups, channels_per_group)``, those two axes are
        transposed, and the result is flattened back to ``in_channels``.

        If ``self.groups`` is ``None`` it is derived on first use as
        ``in_channels // self.groups_factor`` and stored on the layer.

        Raises:
            ValueError: if ``in_channels`` is not divisible by
                ``self.groups_factor`` (when deriving the group count) or by
                an already-set ``self.groups``.
        """
        height, width, in_channels = x.shape.as_list()[1:]

        if self.groups is None:
            if in_channels % self.groups_factor:
                raise ValueError("%s %% %s" %
                                 (in_channels, self.groups_factor))

            self.groups = in_channels // self.groups_factor
        elif in_channels % self.groups:
            # Without this check groups * channels_per_group would be smaller
            # than in_channels, and the -1 in the reshape below could absorb
            # the remainder into the batch axis, mixing data across samples
            # instead of failing.
            raise ValueError(
                "in_channels (%s) is not divisible by groups (%s)" %
                (in_channels, self.groups))

        channels_per_group = in_channels // self.groups

        # Split channels into (groups, channels_per_group), swap the two
        # axes, then merge them back into a single channel axis.
        x = K.reshape(x, [-1, height, width, self.groups, channels_per_group])
        x = K.permute_dimensions(x, (0, 1, 2, 4, 3))  # transpose
        x = K.reshape(x, [-1, height, width, in_channels])

        return x
示例#4
0
    def step(self, inputs, states):
        """Compute one timestep of the attention-augmented LSTM.

        Args:
            inputs: input for the current timestep. With
                ``self.implementation == 0`` it is sliced directly into the
                four gate inputs; otherwise it is multiplied by the input
                kernel(s) here.
            states: indexable, read positionally as
                ``[h_tm1, c_tm1, dp_mask, rec_dp_mask, x_input]``.

        Returns:
            ``(output, [h, c])``. ``output`` is the attention-weighted
            ``x_input`` (before it is summed over axis 1) when
            ``self.return_attention`` is truthy, and ``h`` otherwise.

        Raises:
            ValueError: if ``self.implementation`` is not 0, 1 or 2.
        """
        prev_h, prev_c = states[0], states[1]
        in_mask, rec_mask = states[2], states[3]
        x_input = states[4]

        # alignment model
        tiled_h = K.repeat(prev_h, self.timestep_dim)
        projected_x = _time_distributed_dense(
            x_input,
            self.attention_weights,
            self.attention_bias,
            output_dim=K.int_shape(self.attention_weights)[1])
        scores = self.attention_activation(
            K.dot(tiled_h, self.attention_recurrent_weights) + projected_x)
        scores = K.squeeze(K.dot(scores, self.attention_recurrent_bias), 2)

        # Normalise exp(scores) along axis 1 into attention weights.
        # NOTE(review): exp is applied to the raw scores with no
        # max-subtraction -- confirm the scores are bounded.
        weights = K.exp(scores)
        if in_mask is not None:
            weights = weights * in_mask[0]
        weights = weights / K.sum(weights, axis=1, keepdims=True)

        # Repeat the weights self.input_dim times and move that axis last so
        # they can be multiplied elementwise with x_input.
        tiled_weights = K.permute_dimensions(
            K.repeat(weights, self.input_dim), (0, 2, 1))

        # make context vector (soft attention after Bahdanau et al.)
        context_sequence = x_input * tiled_weights
        context = K.sum(context_sequence, axis=1)

        units = self.units
        if self.implementation == 2:
            # Fused path: a single matmul per source, gates are sliced out of
            # the result afterwards.
            z = (K.dot(inputs * in_mask[0], self.kernel) +
                 K.dot(prev_h * rec_mask[0], self.recurrent_kernel) +
                 K.dot(context, self.attention_kernel))

            if self.use_bias:
                z = K.bias_add(z, self.bias)

            z_i = z[:, :units]
            z_f = z[:, units:2 * units]
            z_c = z[:, 2 * units:3 * units]
            z_o = z[:, 3 * units:]

            i = self.recurrent_activation(z_i)
            f = self.recurrent_activation(z_f)
            c = f * prev_c + i * self.activation(z_c)
            o = self.recurrent_activation(z_o)
        else:
            if self.implementation == 0:
                # Gate inputs are taken as slices of `inputs`.
                x_i = inputs[:, :units]
                x_f = inputs[:, units:2 * units]
                x_c = inputs[:, 2 * units:3 * units]
                x_o = inputs[:, 3 * units:]
            elif self.implementation == 1:
                x_i = K.dot(inputs * in_mask[0], self.kernel_i) + self.bias_i
                x_f = K.dot(inputs * in_mask[1], self.kernel_f) + self.bias_f
                x_c = K.dot(inputs * in_mask[2], self.kernel_c) + self.bias_c
                x_o = K.dot(inputs * in_mask[3], self.kernel_o) + self.bias_o
            else:
                raise ValueError('Unknown `implementation` mode.')

            def gate_input(x_part, mask_index, recurrent_kernel,
                           attention_kernel):
                # Input term + recurrent term + attention-context term.
                return (x_part +
                        K.dot(prev_h * rec_mask[mask_index],
                              recurrent_kernel) +
                        K.dot(context, attention_kernel))

            i = self.recurrent_activation(
                gate_input(x_i, 0, self.recurrent_kernel_i, self.attention_i))
            f = self.recurrent_activation(
                gate_input(x_f, 1, self.recurrent_kernel_f, self.attention_f))
            c = f * prev_c + i * self.activation(
                gate_input(x_c, 2, self.recurrent_kernel_c, self.attention_c))
            o = self.recurrent_activation(
                gate_input(x_o, 3, self.recurrent_kernel_o, self.attention_o))

        h = o * self.activation(c)
        if self.dropout + self.recurrent_dropout > 0:
            # Flag the output as learning-phase dependent when any dropout
            # rate is positive.
            h._uses_learning_phase = True

        output = context_sequence if self.return_attention else h
        return output, [h, c]