Exemplo n.º 1
0
def dynamic_rnn(cell,
                inputs,
                sequence_length=None,
                initial_state=None,
                dtype=None,
                parallel_iterations=None,
                swap_memory=False,
                time_major=False,
                scope=None):
    """Run ``_dynamic_rnn`` on an input whose extra leading axes are flattened.

    ``inputs`` is flattened with ``flatten(inputs, 2)`` (the result is
    described as ``[-1, J, d]``). The RNN is run on that flattened tensor, and
    the outputs are reshaped back with ``reconstruct(flat_outputs, inputs, 2)``.

    Args:
        cell: RNN cell, forwarded unchanged to ``_dynamic_rnn``.
        inputs: batch-major input tensor; presumably ``[..., J, d]`` with any
            number of leading axes — TODO confirm against callers.
        sequence_length: optional lengths. They are flattened with
            ``flatten(sequence_length, 0)`` and cast to ``int64``. Presumably
            their leading shape matches that of ``inputs`` — verify.
        initial_state, dtype, parallel_iterations, swap_memory, scope:
            forwarded unchanged to ``_dynamic_rnn``.
        time_major: only ``False`` is supported.

    Returns:
        ``(outputs, final_state)``. ``outputs`` is reconstructed to the
        leading shape of ``inputs``. ``final_state`` is returned exactly as
        ``_dynamic_rnn`` produced it for the flattened batch, i.e. it is NOT
        reconstructed.

    Raises:
        NotImplementedError: if ``time_major`` is truthy.
    """
    # Explicit check rather than ``assert``: asserts are removed under
    # ``python -O``, which would let time-major input reach the flatten /
    # reconstruct logic below. That logic was written only for batch-major
    # input (see the original TODO).
    if time_major:
        raise NotImplementedError(
            'dynamic_rnn: time_major=True is not implemented yet')

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(
        flatten(sequence_length, 0), 'int64')

    flat_outputs, final_state = _dynamic_rnn(
        cell,
        flat_inputs,
        sequence_length=flat_len,
        initial_state=initial_state,
        dtype=dtype,
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory,
        time_major=time_major,
        scope=scope)

    outputs = reconstruct(flat_outputs, inputs, 2)
    return outputs, final_state
Exemplo n.º 2
0
 def __call__(self,
              inputs,
              seq_len,
              return_last_state=False,
              keep_prob=None,
              is_train=None):
     """Run ``self.cell`` over ``inputs`` and apply dropout to the result.

     Args:
         inputs: sequence input handed to ``_dynamic_rnn``.
         seq_len: per-sequence lengths, passed as ``sequence_length``.
         return_last_state: when true, return the second component of the
             final RNN state instead of the per-step outputs. For an
             LSTM-style ``(c, h)`` state this would presumably be ``h`` —
             confirm against the configured cell.
         keep_prob: forwarded to ``dropout``.
         is_train: forwarded to ``dropout``.

     Returns:
         The selected tensor after ``dropout(..., keep_prob, is_train)``.
     """
     with tf.variable_scope(self.scope):
         # One RNN invocation; afterwards keep only the piece the caller wants.
         rnn_outputs, rnn_state = _dynamic_rnn(self.cell,
                                               inputs,
                                               sequence_length=seq_len,
                                               dtype=tf.float32)
         if return_last_state:
             # The final state must unpack into exactly two parts; keep the second.
             _, selected = rnn_state
         else:
             selected = rnn_outputs
         return dropout(selected, keep_prob, is_train)
Exemplo n.º 3
0
def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):
    """Run ``_dynamic_rnn`` on an input whose extra leading axes are flattened.

    ``inputs`` is flattened with ``flatten(inputs, 2)`` (described as
    ``[-1, J, d]``). The RNN is run on the flattened tensor, and its outputs
    are reshaped back with ``reconstruct(flat_outputs, inputs, 2)``.

    Args:
        cell: RNN cell, forwarded unchanged to ``_dynamic_rnn``.
        inputs: batch-major input tensor; presumably ``[..., J, d]`` — TODO
            confirm against callers.
        sequence_length: optional lengths. They are flattened with
            ``flatten(sequence_length, 0)`` and cast to ``int64``.
        initial_state, dtype, parallel_iterations, swap_memory, scope:
            forwarded unchanged to ``_dynamic_rnn``.
        time_major: only ``False`` is supported.

    Returns:
        ``(outputs, final_state)``. ``outputs`` is reconstructed to the
        leading shape of ``inputs``. ``final_state`` is returned as produced
        for the flattened batch (not reconstructed).

    Raises:
        NotImplementedError: if ``time_major`` is truthy.
    """
    # Explicit check rather than ``assert`` (asserts vanish under ``python -O``),
    # so time-major input can never reach the batch-major-only logic below.
    if time_major:
        raise NotImplementedError(
            'dynamic_rnn: time_major=True is not implemented yet')

    flat_inputs = flatten(inputs, 2)  # [-1, J, d]
    flat_len = None if sequence_length is None else tf.cast(flatten(sequence_length, 0), 'int64')

    flat_outputs, final_state = _dynamic_rnn(cell, flat_inputs, sequence_length=flat_len,
                                             initial_state=initial_state, dtype=dtype,
                                             parallel_iterations=parallel_iterations, swap_memory=swap_memory,
                                             time_major=time_major, scope=scope)

    outputs = reconstruct(flat_outputs, inputs, 2)
    return outputs, final_state