def gru(opts, inputs):
    """Build a GRU layer over time-major *inputs*.

    Selects between the Popnn (poplibs) GRU implementation and a generic
    TensorFlow ``GRUCell`` driven by ``dynamic_rnn``, based on ``opts.popnn``.

    Args:
        opts: configuration object; this function reads ``opts.popnn``
            (bool), ``opts.hidden_size`` (int) and ``opts.train`` (bool).
        inputs: time-major tensor of shape
            [timesteps, batch_size, input_size]; per the original comments,
            input_size is equal to hidden_size.

    Returns:
        The result of the chosen GRU implementation. NOTE(review): the two
        branches may not return identically-shaped results —
        ``tf.nn.dynamic_rnn`` returns an (outputs, state) pair; confirm the
        PopnnGRU call's return matches what callers expect.
    """
    if opts.popnn:
        # PopnnGRU uses a direct poplibs implementation.
        gru_cell = rnn_ops.PopnnGRU(opts.hidden_size, dtype=inputs.dtype)
        # The input is [timesteps, batch_size, input_size],
        # where input_size is equal to hidden_size.
        return gru_cell(inputs, training=opts.train)
    else:
        # The input for GRUCell is instead [batch_size, input_size],
        # where input_size is equal to hidden_size.
        gru_cell = tf.nn.rnn_cell.GRUCell(opts.hidden_size)
        # dynamic_rnn creates a loop that passes slices across timesteps to
        # GRUCell. This is expanded in TensorFlow, creating a less optimal
        # solution than PopnnGRU.
        return tf.nn.dynamic_rnn(cell=gru_cell,
                                 inputs=inputs,
                                 dtype=inputs.dtype,
                                 time_major=True)
def gru(partials):
    """Run a 256-hidden-unit PopnnGRU over *partials* and return its outputs.

    PopnnGRU consumes time-major input, so *partials* (presumably
    [batch, time, features] — TODO confirm against the caller) is transposed
    to time-major before the GRU and the outputs are transposed back to the
    original layout. The GRU's final state is discarded.

    NOTE(review): this redefines the name ``gru`` — if both definitions live
    in the same module, this one shadows the earlier ``gru(opts, inputs)``;
    confirm that is intentional.

    Args:
        partials: rank-3 tensor, batch-major.

    Returns:
        The GRU outputs, transposed back to the input's batch-major layout.
    """
    gru_ = rnn_ops.PopnnGRU(256)
    # Swap the first two axes: batch-major -> time-major for PopnnGRU.
    partial_t = tf.transpose(partials, [1, 0, 2])
    gru_outputs_t, _ = gru_(partial_t)  # second element (final state) unused
    # Back to the caller's batch-major layout.
    return tf.transpose(gru_outputs_t, [1, 0, 2])