Esempio n. 1
0
    def one_rnn_transform(self, W, h, U, x, b):
        """Compute the hyperbolic affine recurrence ``(W*h) + (U*x) + b``.

        The products and sums are the ``util.tf_mob_mat_mul`` /
        ``util.tf_mob_add`` operations (presumably Moebius operations) at
        curvature ``self.c_val``. When ``self.bias_geom == 'eucl'`` the bias is
        first mapped onto the manifold with ``util.tf_exp_map_zero``; otherwise
        it is used as-is.
        """
        curvature = self.c_val
        if self.bias_geom == 'eucl':
            manifold_bias = util.tf_exp_map_zero(b, curvature)
        else:
            manifold_bias = b

        recurrent_term = util.tf_mob_mat_mul(W, h, curvature)
        input_term = util.tf_mob_mat_mul(U, x, curvature)
        combined = util.tf_mob_add(recurrent_term, input_term, curvature)
        return util.tf_mob_add(combined, manifold_bias, curvature)
Esempio n. 2
0
    def create(labels, embeddings, **kwargs):
        """Build a hyperbolic RNN sequence tagger together with its training ops.

        Graph layout: word-embedding lookup -> ``layers`` stacked RNN layers ->
        a feed-forward projection through ``lorentz.tf_mink_dot_matrix`` ->
        (optionally, ``learn_softmax``) a per-class Euclidean or hyperbolic
        multinomial logistic regression -> ``model.probs`` reshaped to
        ``[-1, mxlen, nc]``. Euclidean variables are trained with Adam; word
        embeddings (when ``inputs_geom == 'hyp'``) and hyperbolic variables are
        updated with the Riemannian step ``rsgd`` defined below.

        Args:
            labels: label collection; ``len(labels)`` is the number of classes.
            embeddings: mapping with a required ``'word'`` entry exposing
                ``vocab``, ``dsz`` and ``weights``. A ``'char'`` entry, if
                present, is ignored (the character path is disabled).
            **kwargs: optional pre-built objects (``sess``, ``x``, ``xch``,
                ``y``, ``lengths``, ``pkeep``) and hyper-parameters. ``hsz`` is
                required. ``lr_words`` falls back to the legacy misspelled key
                ``lw_words``. ``batch_sz`` (default 10) is the batch size passed
                to ``cell.zero_state`` for the ``'rnn'`` / ``'bi'`` variants.

        Returns:
            The populated ``HyperbolicRNNModel``.

        Raises:
            KeyError: if ``embeddings['word']`` or ``kwargs['hsz']`` is missing.
            ValueError: if the word embeddings are ``None`` or ``cell_type`` /
                ``sent_geom`` / ``mlr_geom`` select an unimplemented combination.

        Note:
            ``rsgd`` reads ``model.burn_in_factor``, which is not set here --
            presumably provided by ``HyperbolicRNNModel`` itself; confirm.
        """
        # ------------------------------------------------------------------
        # Configuration (read and validated up front so bad setups fail fast).
        # ------------------------------------------------------------------
        word_vec = embeddings['word']
        if word_vec is None:
            # The embedding lookup and the RSGD update on ``W`` both need the
            # word table; a character-only model is not implemented.
            raise ValueError("HyperbolicRNNModel requires word embeddings")

        nc = len(labels)
        hsz = int(kwargs['hsz'])
        pdrop = kwargs.get('dropout', 0.5)
        pdrop_in = kwargs.get('dropin', 0.0)
        rnntype = kwargs.get('rnntype', 'blstm')
        print(rnntype)
        layers = int(kwargs.get('layers', 1))

        inputs_geom = kwargs.get("inputs_geom", "hyp")
        bias_geom = kwargs.get("bias_geom", "hyp")
        sent_geom = kwargs.get("sent_geom", "hyp")
        mlr_geom = kwargs.get("mlr_geom", "hyp")
        c_val = kwargs.get("c_val", 1.0)
        cell_non_lin = kwargs.get("cell_non_lin", "id")  # "id/relu/tanh/sigmoid."
        ffnn_non_lin = kwargs.get("ffnn_non_lin", "id")
        cell_type = kwargs.get("cell_type", 'rnn')
        # "lw_words" is a legacy misspelling, still accepted as a fallback.
        lr_words = kwargs.get("lr_words", kwargs.get("lw_words", 0.01))
        lr_ffnn = kwargs.get("lr_ffnn", 0.01)
        optimizer = kwargs.get("optimizer", "rsgd")
        eucl_clip = kwargs.get("eucl_clip", 1.0)
        hyp_clip = kwargs.get("hyp_clip", 1.0)
        before_mlr_dim = kwargs.get("before_mlr_dim", nc)
        learn_softmax = kwargs.get("learn_softmax", True)
        # Passed to cell.zero_state() for the 'rnn' / 'bi' variants, so the
        # explicit initial state is built for exactly this batch size.
        batch_sz = kwargs.get("batch_sz", 10)

        if cell_type != 'rnn' or sent_geom not in ('eucl', 'hyp'):
            raise ValueError(
                "Unsupported cell_type/sent_geom combination %r/%r: only "
                "cell_type='rnn' with sent_geom 'eucl' or 'hyp' is implemented"
                % (cell_type, sent_geom))
        if learn_softmax and mlr_geom not in ('eucl', 'hyp'):
            raise ValueError("Unsupported mlr_geom %r: expected 'eucl' or 'hyp'"
                             % (mlr_geom,))

        print("C_val:", c_val)

        # ------------------------------------------------------------------
        # Model object, session and input placeholders.
        # ------------------------------------------------------------------
        model = HyperbolicRNNModel()

        def _kwarg_or(key, factory):
            # Build the default only when the caller did not supply one:
            # kwargs.get(key, factory()) would always create an orphan
            # tf.Session / placeholder as a side effect.
            return kwargs[key] if key in kwargs else factory()

        model.sess = _kwarg_or('sess', tf.Session)
        model.mxlen = kwargs.get('maxs', 100)
        model.maxw = kwargs.get('maxw', 100)
        model.labels = labels
        model.crf = bool(kwargs.get('crf', False))
        model.crf_mask = bool(kwargs.get('crf_mask', False))
        model.span_type = kwargs.get('span_type')
        model.proj = bool(kwargs.get('proj', False))
        model.feed_input = bool(kwargs.get('feed_input', False))
        model.activation_type = kwargs.get('activation', 'tanh')

        model.x = _kwarg_or(
            'x', lambda: tf.placeholder(tf.int32, [None, model.mxlen], name="x"))
        model.xch = _kwarg_or(
            'xch',
            lambda: tf.placeholder(tf.int32, [None, model.mxlen, model.maxw],
                                   name="xch"))
        model.y = _kwarg_or(
            'y', lambda: tf.placeholder(tf.int64, [None, model.mxlen], name="y"))
        model.lengths = _kwarg_or(
            'lengths', lambda: tf.placeholder(tf.int32, [None], name="lengths"))
        model.pkeep = _kwarg_or(
            'pkeep', lambda: tf.placeholder(tf.float64, name="pkeep"))
        model.pdrop_value = pdrop
        model.pdropin_value = pdrop_in
        model.word_vocab = word_vec.vocab

        eucl_vars = []
        hyp_vars = []

        # ------------------------------------------------------------------
        # Word embeddings (trainable lookup table W).
        # ------------------------------------------------------------------
        with tf.variable_scope("LUT"):
            W = tf.get_variable("W",
                                dtype=tf.float64,
                                initializer=tf.constant_initializer(
                                    word_vec.weights,
                                    dtype=tf.float64,
                                    verify_shape=True),
                                shape=[len(word_vec.vocab), word_vec.dsz],
                                trainable=True)
            word_embeddings = tf.nn.embedding_lookup(W, model.x)

        embedseq = word_embeddings

        # ------------------------------------------------------------------
        # RNN cell factories: every factory accepts (h_dim, layer) so all
        # branches of the layer loop can call it the same way.
        # ------------------------------------------------------------------
        if sent_geom == 'eucl':
            def cell_class(h_dim, layer=0):
                # BasicRNNCell has no per-layer naming; `layer` is accepted
                # only for a uniform call signature.
                return tf.contrib.rnn.BasicRNNCell(h_dim)
        else:
            def cell_class(h_dim, layer=0):
                return LorentzRNN(num_units=h_dim,
                                  inputs_geom=inputs_geom,
                                  bias_geom=bias_geom,
                                  c_val=c_val,
                                  non_lin=cell_non_lin,
                                  fix_biases=False,
                                  fix_matrices=False,
                                  matrices_init_eye=False,
                                  dtype=tf.float64,
                                  layer=layer)

        def _euclidean_vars(cell):
            """Return the Euclidean trainable variables of an RNN cell."""
            tracked = getattr(cell, 'eucl_vars', None)
            if tracked is not None:
                return list(tracked)
            # Stock TF cells (BasicRNNCell for sent_geom='eucl') do not expose
            # ``eucl_vars``; their weights are ordinary Euclidean variables.
            return list(cell.trainable_weights)

        rnnout = embedseq
        for i in range(layers):
            with tf.variable_scope('rnnLayers', reuse=tf.AUTO_REUSE):
                if rnntype == 'rnn':
                    cell = cell_class(hsz, i)
                    initial_state = cell.zero_state(batch_sz, tf.float64)
                    rnnout, _state = tf.nn.dynamic_rnn(
                        cell,
                        rnnout,
                        sequence_length=model.lengths,
                        initial_state=initial_state,
                        dtype=tf.float64)
                    layer_cells = [cell]

                elif rnntype == 'bi':
                    cell_fw = cell_class(hsz, i)
                    cell_bw = cell_class(hsz, i)
                    init_fw = cell_fw.zero_state(batch_sz, tf.float64)
                    init_bw = cell_bw.zero_state(batch_sz, tf.float64)
                    rnnout, _state = tf.nn.bidirectional_dynamic_rnn(
                        cell_fw,
                        cell_bw,
                        rnnout,
                        initial_state_fw=init_fw,
                        initial_state_bw=init_bw,
                        sequence_length=model.lengths,
                        dtype=tf.float64)
                    # Join forward and backward outputs on the feature axis.
                    rnnout = tf.concat(axis=2, values=rnnout)
                    layer_cells = [cell_fw, cell_bw]

                else:
                    # Pass the layer index like the other branches: the
                    # Lorentz factory requires it (the default rnntype
                    # 'blstm' lands here).
                    cell = cell_class(hsz, i)
                    rnnout, _state = tf.nn.dynamic_rnn(
                        cell,
                        rnnout,
                        sequence_length=model.lengths,
                        dtype=tf.float64)
                    layer_cells = [cell]

            for cell in layer_cells:
                eucl_vars += _euclidean_vars(cell)
                if sent_geom == 'hyp':
                    hyp_vars += cell.hyp_vars

        tf.summary.histogram('RNN/rnnout', rnnout)

        # ------------------------------------------------------------------
        # Feed-forward projection.
        # ------------------------------------------------------------------
        hout = rnnout.get_shape()[-1]
        print(rnnout.get_shape())
        with tf.variable_scope("fc"):
            # Flatten [B x T x H] -> [BT x H].
            rnnout_bt_x_h = tf.reshape(rnnout, [-1, hout])

            W_ff_s1 = tf.get_variable(
                'W_ff_s1',
                dtype=tf.float64,
                shape=[hout, before_mlr_dim],
                initializer=tf.contrib.layers.xavier_initializer(
                    dtype=tf.float64))
            tf.summary.histogram("W_ff_s1", W_ff_s1)

            # TODO(MB): ffn should be in hyperbolic space, no?
            eucl_vars.append(W_ff_s1)

            # NOTE: author-flagged shortcut ("cheat for now"): projects with
            # lorentz.tf_mink_dot_matrix against W_ff_s1^T instead of a
            # tangent-space update followed by an exponential map.
            ffnn_s1 = lorentz.tf_mink_dot_matrix(rnnout_bt_x_h,
                                                 tf.transpose(W_ff_s1))

            output_ffnn = util.tf_hyp_non_lin(ffnn_s1,
                                              non_lin=ffnn_non_lin,
                                              hyp_output=True,
                                              c=c_val)
        tf.summary.histogram("output_ffnn", output_ffnn)

        # ------------------------------------------------------------------
        # Multinomial logistic regression (one hyperplane per class).
        # output_ffnn is [BT x before_mlr_dim].
        # ------------------------------------------------------------------
        if learn_softmax:
            print("learning softmax in hyperbolic space")
            A_mlr = []
            P_mlr = []
            logits_list = []
            dtype = tf.float64

            print('output shape', output_ffnn.get_shape())

            with tf.variable_scope("hyper_softmax"):
                for cl in range(nc):
                    with tf.variable_scope('mlp'):
                        A_mlr.append(
                            tf.get_variable('A_mlr' + str(cl),
                                            dtype=dtype,
                                            shape=[1, before_mlr_dim],
                                            initializer=tf.contrib.layers.
                                            xavier_initializer()))
                        eucl_vars.append(A_mlr[cl])

                        P_mlr.append(
                            tf.get_variable(
                                'P_mlr' + str(cl),
                                dtype=dtype,
                                shape=[1, before_mlr_dim],
                                initializer=tf.constant_initializer(0.0)))

                        if mlr_geom == 'eucl':
                            eucl_vars.append(P_mlr[cl])
                            logits_list.append(
                                tf.reshape(
                                    util.tf_dot(-P_mlr[cl] + output_ffnn,
                                                A_mlr[cl]), [-1]))
                        else:  # 'hyp' -- validated above
                            hyp_vars.append(P_mlr[cl])
                            minus_p_plus_x = util.tf_mob_add(
                                -P_mlr[cl], output_ffnn, c_val)
                            norm_a = util.tf_norm(A_mlr[cl])
                            lambda_px = util.tf_lambda_x(minus_p_plus_x, c_val)
                            # Per the original note: (-P (+) X) is [BT, dim]
                            # and px_dot_a is [BT, 1] -- confirm for new dims.
                            px_dot_a = util.tf_dot(
                                minus_p_plus_x, tf.nn.l2_normalize(A_mlr[cl]))
                            logit = 2. / np.sqrt(c_val) * norm_a * tf.asinh(
                                np.sqrt(c_val) * px_dot_a * lambda_px)
                            logits_list.append(logit)

            # Only stack when logits were actually produced; with
            # learn_softmax=False there is no logits_list.
            probs = tf.stack(logits_list, axis=1)
        else:
            probs = output_ffnn

        print("probs shape", probs.get_shape())
        model.probs = tf.reshape(probs, [-1, model.mxlen, nc])
        print("reshaped probs", model.probs.get_shape())
        tf.summary.histogram("probs", model.probs)

        model.best = tf.argmax(model.probs, 2)

        model.loss = model.create_loss()

        # ------------------------------------------------------------------
        # Optimization.
        # ------------------------------------------------------------------
        all_updates_ops = []

        # Euclidean parameters: Adam with per-gradient norm clipping. A
        # variable that does not reach the loss has a None gradient, which
        # clip_by_norm cannot take, so such pairs are skipped.
        optimizer_euclidean_params = tf.train.AdamOptimizer(learning_rate=1e-3)
        eucl_grads = optimizer_euclidean_params.compute_gradients(
            model.loss, eucl_vars)
        capped_eucl_gvs = [(tf.clip_by_norm(grad, eucl_clip), var)
                           for grad, var in eucl_grads if grad is not None]
        all_updates_ops.append(
            optimizer_euclidean_params.apply_gradients(capped_eucl_gvs))

        def rsgd(v, riemannian_g, learning_rate):
            """Riemannian SGD step for hyperbolic parameters.

            With ``optimizer == 'rsgd'`` the step follows the exponential map
            ``lorentz.tf_exp_map_x``; otherwise a plain SGD step is projected
            back with ``util.tf_project_hyp_vecs``.
            """
            if optimizer == 'rsgd':
                return lorentz.tf_exp_map_x(v,
                                            -model.burn_in_factor *
                                            learning_rate * riemannian_g,
                                            c=c_val)
            # Approximate RSGD based on a simple retraction + projection.
            updated_v = v - model.burn_in_factor * learning_rate * riemannian_g
            return util.tf_project_hyp_vecs(updated_v, c_val)

        if inputs_geom == 'hyp':
            # The gradient of an embedding lookup is sparse (values/indices);
            # sum duplicate rows so each word is updated once.
            grads_and_indices_hyp_words = tf.gradients(model.loss, W)
            grads_hyp_words = grads_and_indices_hyp_words[0].values
            repeating_indices = grads_and_indices_hyp_words[0].indices

            unique_indices, idx_in_repeating_indices = tf.unique(
                repeating_indices)

            agg_gradients = tf.unsorted_segment_sum(
                grads_hyp_words, idx_in_repeating_indices,
                tf.shape(unique_indices)[0])
            agg_gradients = tf.clip_by_norm(agg_gradients, hyp_clip)

            unique_word_emb = tf.nn.embedding_lookup(W, unique_indices)

            riemannian_rescaling_factor = util.riemannian_gradient_c(
                unique_word_emb, c=c_val)
            rescaled_gradient = riemannian_rescaling_factor * agg_gradients
            all_updates_ops.append(
                tf.scatter_update(
                    W, unique_indices,
                    rsgd(unique_word_emb, rescaled_gradient,
                         lr_words)))  # Updated rarely

        if hyp_vars:
            hyp_grads = tf.gradients(model.loss, hyp_vars)
            for hyp_var, grad in zip(hyp_vars, hyp_grads):
                if grad is None:
                    # Variable does not influence the loss; nothing to update.
                    continue
                capped_grad = tf.clip_by_norm(grad, hyp_clip)
                riemannian_rescaling_factor = util.riemannian_gradient_c(
                    hyp_var, c=c_val)
                rescaled_gradient = riemannian_rescaling_factor * capped_grad
                all_updates_ops.append(
                    tf.assign(hyp_var,
                              rsgd(hyp_var, rescaled_gradient,
                                   lr_ffnn)))  # Updated frequently

        model.all_optimizer_var_updates_op = tf.group(*all_updates_ops)
        print("all ops: ", model.all_optimizer_var_updates_op)

        model.summary_merged = tf.summary.merge_all()

        model.test_summary_writer = tf.summary.FileWriter(
            './runs/hyper/' + str(os.getpid()), model.sess.graph)

        return model
Esempio n. 3
0
    def __call__(self, inputs, state, scope=None):
        """Run one step of the hyperbolic LSTM-style cell.

        On the first call, creates the W (state -> units), U (input -> units)
        and b (bias) variables for the four gates ``z`` (input/update),
        ``f`` (forget), ``o`` (output) and ``c`` (candidate memory). It stores
        them as ``self.W<g>`` / ``self.U<g>`` / ``self.b<g>`` and registers
        trainable ones in ``self.eucl_vars`` / ``self.hyp_vars``.

        NOTE(review): ``state`` is fed to every gate as the recurrent input
        *and* used as the previous memory, and the returned state is the
        memory ``c``. No separate hidden state is carried, unlike a textbook
        LSTM. Confirm this is intended (``state_size`` is not visible here).

        Args:
            inputs: input batch; dimension 1 must be statically known (assumed
                rank-2 ``(batch, depth)`` -- confirm).
            state: previous cell state.
            scope: optional variable-scope name (defaults to the class name).

        Returns:
            ``(new_h, c)``: the step output and the new state.

        Raises:
            ValueError: if ``inputs.get_shape()[1]`` is unknown.
        """
        with tf.variable_scope(scope or type(self).__name__):
            if not self.built:
                inputs_shape = inputs.get_shape()
                print('Init LSTM cell')
                depth = inputs_shape[1].value
                if depth is None:
                    raise ValueError(
                        "Expected inputs.shape[-1] to be known, saw shape: %s"
                        % inputs_shape)

                for gate in ('z', 'f', 'o', 'c'):
                    # Recurrent matrix W, then input matrix U.
                    for kind, shape in (
                            ('W', [self._num_units, self._num_units]),
                            ('U', [depth, self._num_units])):
                        weight = tf.get_variable(
                            kind + '_' + gate + str(self.layer),
                            dtype=self.__dtype,
                            shape=shape,
                            trainable=(not self.fix_matrices),
                            initializer=self.matrix_initializer)
                        setattr(self, kind + gate, weight)
                        if not self.fix_matrices:
                            self.eucl_vars.append(weight)

                    bias = tf.get_variable(
                        'b_' + gate + str(self.layer),
                        dtype=self.__dtype,
                        shape=[1, self._num_units],
                        trainable=(not self.fix_biases),
                        initializer=tf.constant_initializer(0.0))
                    setattr(self, 'b' + gate, bias)
                    if not self.fix_biases:
                        bucket = (self.hyp_vars if self.bias_geom == 'hyp'
                                  else self.eucl_vars)
                        bucket.append(bias)

                self.built = True

            if self.inputs_geom == 'eucl':
                hyp_x = util.tf_exp_map_zero(inputs, self.c_val)
            else:
                hyp_x = inputs

            def sigmoid_gate(W, U, b):
                # Gate activations are left Euclidean (hyp_output=False).
                return util.tf_hyp_non_lin(
                    self.one_rnn_transform(W, state, U, hyp_x, b),
                    non_lin='sigmoid',
                    hyp_output=False,
                    c=self.c_val)

            in_gate = sigmoid_gate(self.Wz, self.Uz, self.bz)
            forget_gate = sigmoid_gate(self.Wf, self.Uf, self.bf)
            out_gate = sigmoid_gate(self.Wo, self.Uo, self.bo)

            candidate = util.tf_hyp_non_lin(
                self.one_rnn_transform(self.Wc, state, self.Uc, hyp_x,
                                       self.bc),
                non_lin=self.non_lin,
                hyp_output=True,
                c=self.c_val)

            # memory = (state . f) (+) (candidate . i), then the cell
            # non-linearity.
            kept = util.tf_mob_pointwise_prod(state, forget_gate, self.c_val)
            written = util.tf_mob_pointwise_prod(candidate, in_gate,
                                                 self.c_val)
            memory = util.tf_hyp_non_lin(
                util.tf_mob_add(kept, written, self.c_val),
                non_lin=self.non_lin,
                hyp_output=True,
                c=self.c_val)

            output = util.tf_mob_pointwise_prod(memory, out_gate, self.c_val)
        return output, memory
Esempio n. 4
0
    def __call__(self, inputs, state, scope=None):
        """Run one step of the hyperbolic GRU cell.

        On the first call, creates the W (state -> units), U (input -> units)
        and b (bias) variables for the update gate ``z``, reset gate ``r`` and
        candidate ``h``. It stores them as ``self.W<g>`` / ``self.U<g>`` /
        ``self.b<g>`` and registers trainable ones in ``self.eucl_vars`` /
        ``self.hyp_vars``.

        The step is::

            z = sigmoid(T(W_z, state, U_z, x, b_z))
            r = sigmoid(T(W_r, state, U_r, x, b_r))
            h~ = non_lin(T(W_h, state . r, U_h, x, b_h))
            new_h = state (+) ((-state (+) h~) . z)

        where ``T`` is ``self.one_rnn_transform``, ``(+)`` is
        ``util.tf_mob_add`` and ``.`` is ``util.tf_mob_pointwise_prod``.

        Args:
            inputs: input batch; dimension 1 must be statically known (assumed
                rank-2 ``(batch, depth)`` -- confirm).
            state: previous hidden state.
            scope: optional variable-scope name (defaults to the class name).

        Returns:
            ``(new_h, new_h)``: the output is also the new state.

        Raises:
            ValueError: if ``inputs.get_shape()[1]`` is unknown.
        """
        with tf.variable_scope(scope or type(self).__name__):
            if not self.built:
                inputs_shape = inputs.get_shape()
                print('Init GRU cell')
                depth = inputs_shape[1].value
                if depth is None:
                    raise ValueError(
                        "Expected inputs.shape[-1] to be known, saw shape: %s"
                        % inputs_shape)

                for gate in ('z', 'r', 'h'):
                    # Recurrent matrix W, then input matrix U.
                    for kind, shape in (
                            ('W', [self._num_units, self._num_units]),
                            ('U', [depth, self._num_units])):
                        weight = tf.get_variable(
                            kind + '_' + gate + str(self.layer),
                            dtype=self.__dtype,
                            shape=shape,
                            trainable=(not self.fix_matrices),
                            initializer=self.matrix_initializer)
                        setattr(self, kind + gate, weight)
                        if not self.fix_matrices:
                            self.eucl_vars.append(weight)

                    bias = tf.get_variable(
                        'b_' + gate + str(self.layer),
                        dtype=self.__dtype,
                        shape=[1, self._num_units],
                        trainable=(not self.fix_biases),
                        initializer=tf.constant_initializer(0.0))
                    setattr(self, 'b' + gate, bias)
                    if not self.fix_biases:
                        bucket = (self.hyp_vars if self.bias_geom == 'hyp'
                                  else self.eucl_vars)
                        bucket.append(bias)

                self.built = True

            if self.inputs_geom == 'eucl':
                hyp_x = util.tf_exp_map_zero(inputs, self.c_val)
            else:
                hyp_x = inputs

            def sigmoid_gate(W, U, b):
                # Gate activations are left Euclidean (hyp_output=False).
                return util.tf_hyp_non_lin(
                    self.one_rnn_transform(W, state, U, hyp_x, b),
                    non_lin='sigmoid',
                    hyp_output=False,
                    c=self.c_val)

            update_gate = sigmoid_gate(self.Wz, self.Uz, self.bz)
            reset_gate = sigmoid_gate(self.Wr, self.Ur, self.br)

            reset_state = util.tf_mob_pointwise_prod(state, reset_gate,
                                                     self.c_val)
            candidate = util.tf_hyp_non_lin(
                self.one_rnn_transform(self.Wh, reset_state, self.Uh, hyp_x,
                                       self.bh),
                non_lin=self.non_lin,
                hyp_output=True,
                c=self.c_val)

            step = util.tf_mob_add(-state, candidate, self.c_val)
            gated_step = util.tf_mob_pointwise_prod(step, update_gate,
                                                    self.c_val)
            next_state = util.tf_mob_add(state, gated_step, self.c_val)
        return next_state, next_state