コード例 #1
0
ファイル: gru_cell.py プロジェクト: tabilab-dip/BOUN-PARS
 def __call__(self, inputs, state, scope=None):
   """Run one step of the GRU cell.

   Args:
     inputs: Input tensor for this time step; it is fed to `linear` together
       with the previous hidden state.
     state: Previous recurrent state. It is split into two equal halves along
       axis 1 as `(cell_tm1, hidden_tm1)`, so axis 1 must have even size
       (presumably `2 * self.output_size` -- TODO confirm against callers).
     scope: Optional variable-scope name; the class name is used when omitted.

   Returns:
     A pair `(cell_t, new_state)`. `new_state` is `cell_t` concatenated with
     itself along axis 1: a GRU keeps a single state vector, and duplicating
     it keeps the two-half layout that the split at the top of this method
     expects.
   """

   with tf.variable_scope(scope or type(self).__name__):
     cell_tm1, hidden_tm1 = tf.split(state, 2, axis=1)
     input_list = [inputs, hidden_tm1]
     with tf.variable_scope('Gates'):
       # One projection of [inputs, hidden_tm1], split into the update- and
       # reset-gate pre-activations. Must pass `input_list` (built above);
       # the name `inputs_list` does not exist and raised NameError.
       gates = linear(input_list,
                      self.output_size,
                      add_bias=True,
                      n_splits=2,
                      moving_params=self.moving_params)
       update_act, reset_act = gates
       # self.forget_bias is subtracted from the update pre-activation before
       # the gate nonlinearity is applied.
       update_gate = gate(update_act-self.forget_bias)
       reset_gate = gate(reset_act)
       reset_state = reset_gate * hidden_tm1
     # The candidate sees the reset-scaled previous state, not the raw one.
     input_list = [inputs, reset_state]
     with tf.variable_scope('Candidate'):
       hidden_act = linear(input_list,
                           self.output_size,
                           add_bias=True,
                           moving_params=self.moving_params)
       hidden_tilde = self.recur_func(hidden_act)
     # update_gate weights the previous state; (1 - update_gate) weights the
     # new candidate.
     cell_t = update_gate * cell_tm1 + (1-update_gate) * hidden_tilde
   return cell_t, tf.concat([cell_t, cell_t], 1)
コード例 #2
0
ファイル: cif_lstm_cell.py プロジェクト: tapika/Parser-v2
    def __call__(self, inputs, state, scope=None):
        """Advance the coupled input/forget-gate LSTM cell by one time step.

        The incoming ``state`` holds the previous cell and hidden tensors side
        by side on axis 1. A single ``linear`` projection of
        ``[inputs, previous hidden]`` is split three ways into the candidate
        cell, update-gate and output-gate pre-activations. The update gate
        ``z`` mixes candidate and previous cell as
        ``z * candidate + (1 - z) * previous``, so one gate serves as both the
        input and the forget gate.

        Args:
            inputs: Input tensor for this time step.
            state: Previous recurrent state; axis 1 is split into two equal
                halves ``(cell, hidden)``.
            scope: Optional variable-scope name; the class name is used when
                omitted.

        Returns:
            ``(hidden, new_state)``, where ``new_state`` is the new cell and
            the new hidden tensor concatenated along axis 1.
        """
        with tf.variable_scope(scope or type(self).__name__):
            prev_cell, prev_hidden = tf.split(state, 2, axis=1)

            # Candidate, update-gate and output-gate pre-activations from one
            # shared projection.
            candidate, z_preact, o_preact = linear(
                [inputs, prev_hidden],
                self.output_size,
                add_bias=True,
                n_splits=3,
                moving_params=self.moving_params)

            # The candidate is used as-is; the nonlinearity (recur_func) is
            # applied to the mixed cell further down instead.
            z = gate(z_preact - self.forget_bias)
            o = gate(o_preact)

            new_cell = z * candidate + (1 - z) * prev_cell
            new_hidden = self.recur_func(new_cell) * o
            # Built inside the scope so the concat op is named under it.
            new_state = tf.concat([new_cell, new_hidden], 1)

        return new_hidden, new_state