Example #1
    def _time_step(time, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t, *state):
        input_t = input_ta.read(time)
        # Restore some shape information
        input_t.set_shape([const_batch_size, const_depth])

        # Pack state back up for use by cell
        state = tfutils.packed_state(structure=state_size, state=state)

        call_cell = lambda: cell(input_t, state)

        (output, alpha, attn_ids, lmbdas, new_state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            zero_output=zero_output,
            zero_alpha=zero_alpha,
            zero_attn_ids=zero_attn_ids,
            zero_lmbdas=zero_lmbdas,
            state=state,
            call_cell=call_cell,
            state_size=state_size,
        )

        # Flatten the new state back out so it can pass through the while_loop
        new_state = tuple(tfutils.unpacked_state(new_state))

        output_ta_t = output_ta_t.write(time, output)
        alpha_ta_t = alpha_ta_t.write(time, alpha)
        attn_ids_ta_t = attn_ids_ta_t.write(time, attn_ids)
        lmbda_ta_t = lmbda_ta_t.write(time, lmbdas)

        return (time + 1, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t) + new_state
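
The tfutils.packed_state and tfutils.unpacked_state helpers used above come from the surrounding project and are not shown in these examples. From their call sites they evidently flatten a (possibly nested) state tuple into a flat list of tensors that tf.while_loop can carry, and rebuild the nested structure afterwards, much like the later tf.nest.flatten / tf.nest.pack_sequence_as. A minimal sketch of plausible implementations, assuming one tensor per leaf of the cell's state_size structure:

def unpacked_state(state):
    # Flatten a possibly nested state tuple into a flat list of tensors.
    if isinstance(state, (tuple, list)):
        flat = []
        for s in state:
            flat.extend(unpacked_state(s))
        return flat
    return [state]

def packed_state(structure, state):
    # Rebuild the nested state from a flat sequence, guided by `structure`
    # (the cell's state_size: an int or a nested tuple of ints).
    flat = list(state)

    def _pack(struct):
        if isinstance(struct, (tuple, list)):
            return tuple(_pack(s) for s in struct)
        return flat.pop(0)  # consume one tensor per leaf of the structure

    return _pack(structure)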
Example #2
def _rnn_step(time, sequence_length, zero_output, zero_alpha, zero_attn_ids,
              zero_lmbdas, state, call_cell, state_size):
    # Convert state to a list for ease of use
    state = list(tfutils.unpacked_state(state))
    state_shape = [s.get_shape() for s in state]

    def _copy_some_through(new_output, new_alpha, new_attn_ids, new_lmbdas,
                           new_state):
        # Use broadcasting select to determine which values should get
        # the previous state & zero output, and which values should get
        # a calculated state & output.

        # Alpha needs to be (batch, tasks, k)
        copy_cond = (time >= sequence_length)
        return ([
            tf.select(copy_cond, zero_output, new_output),
            tf.select(copy_cond, zero_alpha, new_alpha),  # (batch, tasks, k)
            tf.select(copy_cond, zero_attn_ids, new_attn_ids),
            tf.select(copy_cond, zero_lmbdas, new_lmbdas)
        ] + [
            tf.select(copy_cond, old_s, new_s)
            for (old_s, new_s) in zip(state, new_state)
        ])

    new_output, new_alpha, new_attn_ids, new_lmbdas, new_state = call_cell()
    new_state = list(tfutils.unpacked_state(new_state))

    final_output_and_state = _copy_some_through(new_output, new_alpha,
                                                new_attn_ids, new_lmbdas,
                                                new_state)

    (final_output, final_alpha, final_attn_ids, final_lmbdas,
     final_state) = (final_output_and_state[0], final_output_and_state[1],
                     final_output_and_state[2], final_output_and_state[3],
                     final_output_and_state[4:])

    final_output.set_shape(zero_output.get_shape())
    final_alpha.set_shape(zero_alpha.get_shape())
    final_attn_ids.set_shape(zero_attn_ids.get_shape())
    final_lmbdas.set_shape(zero_lmbdas.get_shape())

    for final_state_i, state_shape_i in zip(final_state, state_shape):
        final_state_i.set_shape(state_shape_i)

    return (final_output, final_alpha, final_attn_ids, final_lmbdas,
            tfutils.packed_state(structure=state_size, state=final_state))
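
A note on the "broadcasting select" used above: copy_cond = (time >= sequence_length) has shape (batch,), and tf.select (renamed tf.where in TensorFlow 1.0) treats a rank-1 condition as a per-row mask against higher-rank operands, so finished sequences keep their zero rows wholesale. A standalone sketch of that masking in TF 1.x terms, with made-up shapes:

import tensorflow as tf  # TF 1.x API assumed

time = tf.constant(2, dtype=tf.int32)
sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)

copy_cond = time >= sequence_length          # shape (4,): [True, True, False, False]
zero_output = tf.zeros([4, 3])
new_output = tf.ones([4, 3])

# Rows whose sequence has already ended keep the zero output; the others
# take the freshly computed output.
masked = tf.where(copy_cond, zero_output, new_output)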
Example #3
def _dynamic_attention_rnn_loop(cell, inputs, initial_state, parallel_iterations,
                                swap_memory, sequence_length, attn_length, num_tasks,
                                batch_size):
    state = initial_state

    # Recover dynamic and static shape information from the inputs
    input_shape = tf.shape(inputs)
    (time_steps, _, _) = tf.unpack(input_shape, 3)

    inputs_got_shape = inputs.get_shape().with_rank(3)
    (const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()

    # Prepare dynamic conditional copying of state & output
    zero_output = tf.zeros(tf.pack([batch_size, cell.output_size]), inputs.dtype)
    zero_attn_ids = zero_alpha = tf.zeros([batch_size, num_tasks-1, attn_length], inputs.dtype)
    zero_lmbdas = tf.zeros([batch_size, num_tasks], tf.float32)

    time = tf.constant(0, dtype=tf.int32, name="time")

    state_size = cell.state_size

    state = tfutils.unpacked_state(state)

    with tf.op_scope([], "dynamic_rnn") as scope:
        base_name = scope

    def create_ta(name, dtype=None):
        dtype = dtype or inputs.dtype
        return tf.TensorArray(dtype=dtype, size=time_steps, tensor_array_name=base_name + name)

    output_ta = create_ta("output")
    alpha_ta = create_ta("alpha", tf.float32)
    attn_ids_ta = create_ta("attn_ids")
    lmbda_ta = create_ta("lmbdas", tf.float32)
    input_ta = create_ta("input")

    input_ta = input_ta.unpack(inputs)

    def _time_step(time, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t, *state):
        input_t = input_ta.read(time)
        # Restore some shape information
        input_t.set_shape([const_batch_size, const_depth])

        # Pack state back up for use by cell
        state = tfutils.packed_state(structure=state_size, state=state)

        call_cell = lambda: cell(input_t, state)

        (output, alpha, attn_ids, lmbdas, new_state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            zero_output=zero_output,
            zero_alpha=zero_alpha,
            zero_attn_ids=zero_attn_ids,
            zero_lmbdas=zero_lmbdas,
            state=state,
            call_cell=call_cell,
            state_size=state_size,
        )

        # Flatten the new state back out so it can pass through the while_loop
        new_state = tuple(tfutils.unpacked_state(new_state))

        output_ta_t = output_ta_t.write(time, output)
        alpha_ta_t = alpha_ta_t.write(time, alpha)
        attn_ids_ta_t = attn_ids_ta_t.write(time, attn_ids)
        lmbda_ta_t = lmbda_ta_t.write(time, lmbdas)

        return (time + 1, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t) + new_state

    final_loop_vars = tf.while_loop(
        cond=lambda time, *_: time < time_steps,
        body=_time_step,
        loop_vars=(time, output_ta, alpha_ta, attn_ids_ta, lmbda_ta) + tuple(state),
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    (output_final_ta, alpha_final_ta, attn_ids_final_ta, lmbda_final_ta, final_state) = \
        (final_loop_vars[1], final_loop_vars[2], final_loop_vars[3], final_loop_vars[4], final_loop_vars[5:])

    final_outputs = output_final_ta.pack()
    final_alphas = alpha_final_ta.pack()
    final_attn_ids = attn_ids_final_ta.pack()
    final_lmbdas = lmbda_final_ta.pack()
    # Restore some shape information
    final_outputs.set_shape([const_time_steps, const_batch_size, cell.output_size])
    final_alphas.set_shape([const_time_steps, const_batch_size, num_tasks-1, attn_length])
    final_attn_ids.set_shape([const_time_steps, const_batch_size, num_tasks-1, attn_length])
    final_lmbdas.set_shape([const_time_steps, const_batch_size, num_tasks])

    # Pack the flat final state back into the cell's state structure.
    final_state = tfutils.packed_state(structure=cell.state_size, state=final_state)

    return final_outputs, final_alphas, final_attn_ids, final_lmbdas, final_state
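
For orientation, a hypothetical call site for the loop above. The cell here is a stub invented for this sketch (none of these names appear in the original examples); the one real requirement, visible in _time_step, is that the cell return the five-tuple (output, alpha, attn_ids, lmbdas, new_state):

import tensorflow as tf  # TF 0.x-era API, matching the example above

# Illustrative constants, not from the original code.
batch_size, depth, attn_length, num_tasks = 4, 8, 5, 3

class StubAttentionCell(object):
    # Minimal stand-in honoring the five-output contract the loop expects.
    output_size = depth  # tanh below preserves the input depth
    state_size = depth   # a single flat state tensor, for simplicity

    def zero_state(self, batch, dtype):
        return tf.zeros([batch, self.state_size], dtype)

    def __call__(self, input_t, state):
        output = tf.tanh(input_t)
        alpha = tf.zeros([batch_size, num_tasks - 1, attn_length])
        attn_ids = tf.zeros([batch_size, num_tasks - 1, attn_length])
        lmbdas = tf.zeros([batch_size, num_tasks])
        return output, alpha, attn_ids, lmbdas, state

cell = StubAttentionCell()
inputs = tf.placeholder(tf.float32, [None, batch_size, depth])  # time-major
seq_len = tf.placeholder(tf.int32, [batch_size])

outputs, alphas, attn_ids, lmbdas, final_state = _dynamic_attention_rnn_loop(
    cell=cell, inputs=inputs,
    initial_state=cell.zero_state(batch_size, tf.float32),
    parallel_iterations=32, swap_memory=False, sequence_length=seq_len,
    attn_length=attn_length, num_tasks=num_tasks, batch_size=batch_size)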
Example #4
def _dynamic_attention_rnn_loop(cell, inputs, initial_state, parallel_iterations,
                                swap_memory, sequence_length, attn_length, num_tasks,
                                batch_size):
    state = initial_state

    # Recover dynamic and static shape information from the inputs
    input_shape = tf.shape(inputs)
    (time_steps, _, _) = tf.unstack(input_shape, 3)

    inputs_got_shape = inputs.get_shape().with_rank(3)
    (const_time_steps, const_batch_size, const_depth) = inputs_got_shape.as_list()

    # Prepare dynamic conditional copying of state & output
    zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]), inputs.dtype)
    zero_attn_ids = zero_alpha = tf.zeros([batch_size, num_tasks-1, attn_length], inputs.dtype)
    zero_lmbdas = tf.zeros([batch_size, num_tasks], tf.float32)

    time = tf.constant(0, dtype=tf.int32, name="time")  # TensorArray indices are int32

    state_size = cell.state_size

    state = tfutils.unpacked_state(state)

    with tf.name_scope(values=[], name="dynamic_rnn") as scope:
        base_name = scope

    def create_ta(name, dtype=None):
        dtype = dtype or inputs.dtype
        return tf.TensorArray(dtype=dtype, size=time_steps, tensor_array_name=base_name + name)

    output_ta = create_ta("output")
    alpha_ta = create_ta("alpha", tf.float32)
    attn_ids_ta = create_ta("attn_ids")
    lmbda_ta = create_ta("lmbdas", tf.float32)
    input_ta = create_ta("input")

    input_ta = input_ta.unstack(inputs)

    def _time_step(time, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t, *state):
        input_t = input_ta.read(time)
        # Restore some shape information
        input_t.set_shape([const_batch_size, const_depth])

        # Pack state back up for use by cell
        state = tfutils.packed_state(structure=state_size, state=state)

        call_cell = lambda: cell(input_t, state)

        (output, alpha, attn_ids, lmbdas, new_state) = _rnn_step(
            time=time,
            sequence_length=sequence_length,
            zero_output=zero_output,
            zero_alpha=zero_alpha,
            zero_attn_ids=zero_attn_ids,
            zero_lmbdas=zero_lmbdas,
            state=state,
            call_cell=call_cell,
            state_size=state_size,
        )

        # Flatten the new state back out so it can pass through the while_loop
        new_state = tuple(tfutils.unpacked_state(new_state))

        output_ta_t = output_ta_t.write(time, output)
        alpha_ta_t = alpha_ta_t.write(time, alpha)
        attn_ids_ta_t = attn_ids_ta_t.write(time, attn_ids)
        lmbda_ta_t = lmbda_ta_t.write(time, lmbdas)

        return (time + 1, output_ta_t, alpha_ta_t, attn_ids_ta_t, lmbda_ta_t) + new_state

    final_loop_vars = tf.while_loop(
        cond=lambda time, *_: time < time_steps,
        body=_time_step,
        loop_vars=(time, output_ta, alpha_ta, attn_ids_ta, lmbda_ta) + tuple(state),
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    (output_final_ta, alpha_final_ta, attn_ids_final_ta, lmbda_final_ta, final_state) = \
        (final_loop_vars[1], final_loop_vars[2], final_loop_vars[3], final_loop_vars[4], final_loop_vars[5:])

    final_outputs = output_final_ta.stack()
    final_alphas = alpha_final_ta.stack()
    final_attn_ids = attn_ids_final_ta.stack()
    final_lmbdas = lmbda_final_ta.stack()
    # Restore some shape information
    final_outputs.set_shape([const_time_steps, const_batch_size, cell.output_size])
    final_alphas.set_shape([const_time_steps, const_batch_size, num_tasks-1, attn_length])
    final_attn_ids.set_shape([const_time_steps, const_batch_size, num_tasks-1, attn_length])
    final_lmbdas.set_shape([const_time_steps, const_batch_size, num_tasks])

    # Pack the flat final state back into the cell's state structure.
    final_state = tfutils.packed_state(structure=cell.state_size, state=final_state)

    return final_outputs, final_alphas, final_attn_ids, final_lmbdas, final_state
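
Examples #3 and #4 are the same loop targeting different TensorFlow generations; the renames separating them are the standard TF 1.0 API migrations:

# TF 0.x              ->  TF 1.x
# tf.pack             ->  tf.stack
# tf.unpack           ->  tf.unstack
# tf.op_scope         ->  tf.name_scope
# tf.select           ->  tf.where
# TensorArray.pack    ->  TensorArray.stack
# TensorArray.unpack  ->  TensorArray.unstack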