Example no. 1
0
    def __call__(self, is_train, scope=None):
        """Build an ``InitializedLSTMCell`` from this spec's configuration.

        Resolves the configured activations/initializers by name and raises
        ``ValueError`` if any of them fails to resolve.
        """
        act = get_keras_activation(self.activation)
        rec_act = get_keras_activation(self.recurrent_activation)
        kernel_init = get_keras_initialization(self.kernel_initializer)
        rec_init = get_keras_initialization(self.recurrent_initializer)

        # All four pieces must resolve, otherwise the spec is unusable.
        if any(v is None for v in (act, rec_act, kernel_init, rec_init)):
            raise ValueError()

        return InitializedLSTMCell(self.num_units, kernel_init, rec_init,
                                   act, rec_act, self.forget_bias,
                                   self.keep_recurrent_probs, is_train, scope)
Example no. 2
0
 def __call__(self, is_train, scope=None):
     """Construct a ``GRUCell`` configured from this spec."""
     # Resolve the configured activation/initializers by name.
     act_fn = get_keras_activation(self.activation)
     k_init = get_keras_initialization(self.kernel_initializer)
     r_init = get_keras_initialization(self.recurrent_initializer)
     c_init = get_keras_initialization(self.candidate_initializer)
     # NOTE(review): `bais_init` (sic) is the attribute's actual spelling
     # in this codebase — do not "correct" it here.
     bias = tf.constant_initializer(self.bais_init)
     return GRUCell(self.num_units, bias, k_init, r_init, c_init, act_fn)
    def _distance_logits(self, x, keys):
        """Additive (Bahdanau-style) pairwise scoring of `x` against `keys`.

        Both inputs are linearly projected to ``self.projected_size``, every
        (x position, key position) pair is summed via broadcasting, passed
        through the configured activation, and reduced to a scalar per pair.

        Returns a logit tensor of shape (batch, x_len, key_len).
        """
        init = get_keras_initialization(self.init)

        key_w = tf.get_variable("key_w", dtype=tf.float32, initializer=init,
                                shape=(keys.shape.as_list()[-1],
                                       self.projected_size))
        # (batch, key_len, projected_size)
        key_logits = tf.tensordot(keys, key_w, axes=[[2], [0]])

        # Optionally share the key projection for x as well.
        if self.shared_project:
            x_w = key_w
        else:
            x_w = tf.get_variable("x_w", dtype=tf.float32, initializer=init,
                                  shape=(x.shape.as_list()[-1],
                                         self.projected_size))
        # (batch, x_len, projected_size)
        x_logits = tf.tensordot(x, x_w, axes=[[2], [0]])

        # Broadcast-add every x position against every key position:
        # (batch, x_len, 1, proj) + (batch, 1, key_len, proj)
        #   -> (batch, x_len, key_len, projected_size)
        summed = tf.expand_dims(x_logits, axis=2) + \
            tf.expand_dims(key_logits, axis=1)
        summed = get_keras_activation(self.activation)(summed)

        combine_w = tf.get_variable("combine_w", dtype=tf.float32,
                                    initializer=init,
                                    shape=self.projected_size)
        # Collapse the projected dimension: (batch, x_len, key_len)
        return tf.tensordot(summed, combine_w, axes=[[3], [0]])