Example #1
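# Excerpt: the inner wrapper produced by an argument-deprecation decorator
# (it appears to be taken from TensorFlow's tensorflow/python/util/deprecation.py).
# It is not standalone: func, deprecated_positions, arg_spec, is_varargs_deprecated,
# is_kwargs_deprecated, deprecated_arg_names, warn_once, date and instructions are
# supplied by the enclosing decorator's closure, and _PRINTED_WARNING,
# _PRINT_DEPRECATION_WARNINGS, _call_location and _same_value by its module.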
def new_func(*args, **kwargs):
  """Deprecation wrapper."""
  # TODO(apassos) figure out a way to have reasonable performance with
  # deprecation warnings and eager mode.
  if is_in_graph_mode.IS_IN_GRAPH_MODE() and _PRINT_DEPRECATION_WARNINGS:
    invalid_args = []
    named_args = tf_inspect.getcallargs(func, *args, **kwargs)
    for arg_name, spec in iter(deprecated_positions.items()):
      if (spec.position < len(args) and
          not (spec.has_ok_value and
               _same_value(named_args[arg_name], spec.ok_value))):
        invalid_args.append(arg_name)
    if is_varargs_deprecated and len(args) > len(arg_spec.args):
      invalid_args.append(arg_spec.varargs)
    if is_kwargs_deprecated and kwargs:
      invalid_args.append(arg_spec.varkw)
    for arg_name in deprecated_arg_names:
      if (arg_name in kwargs and
          not (deprecated_positions[arg_name].has_ok_value and
               _same_value(named_args[arg_name],
                           deprecated_positions[arg_name].ok_value))):
        invalid_args.append(arg_name)
    for arg_name in invalid_args:
      if (func, arg_name) not in _PRINTED_WARNING:
        if warn_once:
          _PRINTED_WARNING[(func, arg_name)] = True
        logging.warning(
            'From %s: calling %s (from %s) with %s is deprecated and will '
            'be removed %s.\nInstructions for updating:\n%s',
            _call_location(), decorator_utils.get_qualified_name(func),
            func.__module__, arg_name,
            'in a future version' if date is None else ('after %s' % date),
            instructions)
  return func(*args, **kwargs)
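

# The raw_rnn variant below relies on TensorFlow 1.x internal modules. The
# original snippet does not include its imports; this block is a best-effort
# reconstruction, and the exact module paths may vary between TF 1.x releases.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import is_in_graph_mode
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.rnn import _maybe_tensor_shape_from_tensor
from tensorflow.python.ops.rnn_cell_impl import _concat, assert_like_rnncell
from tensorflow.python.util import nest
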
def raw_rnn(cell,
            loop_fn,
            parallel_iterations=None,
            swap_memory=False,
            scope=None):
    """
    raw_rnn adapted from the original tensorflow implementation
    (https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/rnn.py)
    to emit arbitrarily nested states for each time step (concatenated along the time axis)
    in addition to the outputs at each timestep and the final state

    returns (
        states for all timesteps,
        outputs for all timesteps,
        final cell state,
    )
    """
    assert_like_rnncell("Raw rnn cell", cell)
    if not callable(loop_fn):
        raise TypeError("loop_fn must be a callable")

    parallel_iterations = parallel_iterations or 32

    # Create a new scope in which the caching device is either
    # determined by the parent scope, or is set to place the cached
    # Variable using the same placement as for the rest of the RNN.
    with vs.variable_scope(scope or "rnn") as varscope:
        if is_in_graph_mode.IS_IN_GRAPH_MODE():
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        # The time-0 call to loop_fn provides the first input, the initial
        # cell state, an optional template for the emitted structure, and an
        # optional initial loop state.
        time = constant_op.constant(0, dtype=dtypes.int32)
        (elements_finished, next_input, initial_state, emit_structure,
         init_loop_state) = loop_fn(time, None, None, None)
        flat_input = nest.flatten(next_input)

        # Need a surrogate loop state for the while_loop if none is available.
        loop_state = (init_loop_state if init_loop_state is not None else
                      constant_op.constant(0, dtype=dtypes.int32))

        input_shape = [input_.get_shape() for input_ in flat_input]
        static_batch_size = input_shape[0][0]

        for input_shape_i in input_shape:
            # Static verification that batch sizes all match
            static_batch_size.merge_with(input_shape_i[0])

        batch_size = static_batch_size.value
        const_batch_size = batch_size
        if batch_size is None:
            batch_size = array_ops.shape(flat_input[0])[0]

        nest.assert_same_structure(initial_state, cell.state_size)
        state = initial_state
        flat_state = nest.flatten(state)
        flat_state = [ops.convert_to_tensor(s) for s in flat_state]
        state = nest.pack_sequence_as(structure=state,
                                      flat_sequence=flat_state)

        # Per-step emit shapes/dtypes come from the emit structure returned by
        # loop_fn at time 0 (shapes given without the batch dimension); if it
        # returned None, fall back to cell.output_size.
        if emit_structure is not None:
            flat_emit_structure = nest.flatten(emit_structure)
            flat_emit_size = [
                emit.shape
                if emit.shape.is_fully_defined() else array_ops.shape(emit)
                for emit in flat_emit_structure
            ]
            flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
        else:
            emit_structure = cell.output_size
            flat_emit_size = nest.flatten(emit_structure)
            flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

        # Per-leaf state sizes must exclude the batch dimension, since the
        # TensorArray element_shape below prepends const_batch_size itself
        # (using the full state shape here would prepend the batch size twice).
        flat_state_size = [
            s.shape[1:] if s.shape.is_fully_defined()
            else array_ops.shape(s)[1:]
            for s in flat_state
        ]
        flat_state_dtypes = [s.dtype for s in flat_state]

        flat_emit_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_output_%d" % i)
            for i, (dtype_i,
                    size_i) in enumerate(zip(flat_emit_dtypes, flat_emit_size))
        ]
        emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                        flat_sequence=flat_emit_ta)
        flat_zero_emit = [
            array_ops.zeros(_concat(batch_size, size_i), dtype_i)
            for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)
        ]

        zero_emit = nest.pack_sequence_as(structure=emit_structure,
                                          flat_sequence=flat_zero_emit)

        # Mirror the emit TensorArrays above with one TensorArray per flattened
        # state leaf, so the state at every time step can be recorded as well.
        flat_state_ta = [
            tensor_array_ops.TensorArray(
                dtype=dtype_i,
                dynamic_size=True,
                element_shape=(tensor_shape.TensorShape([
                    const_batch_size
                ]).concatenate(_maybe_tensor_shape_from_tensor(size_i))),
                size=0,
                name="rnn_state_%d" % i)
            for i, (
                dtype_i,
                size_i) in enumerate(zip(flat_state_dtypes, flat_state_size))
        ]
        state_ta = nest.pack_sequence_as(structure=state,
                                         flat_sequence=flat_state_ta)

        def condition(unused_time, elements_finished, *_):
            # Keep iterating until every sequence in the batch has finished.
            return math_ops.logical_not(math_ops.reduce_all(elements_finished))

        def body(time, elements_finished, current_input, state_ta, emit_ta,
                 state, loop_state):
            # One step: run the cell, ask loop_fn for the next input and the
            # updated finished mask, then record this step's emitted output
            # and state in the TensorArrays.
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i,
                                               cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                         emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)

        returned = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                time, elements_finished, next_input, state_ta, emit_ta, state,
                loop_state
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        (state_ta, emit_ta, final_state, final_loop_state) = returned[-4:]

        # Each TensorArray stacks to a time-major tensor [time, batch, depth]
        # (assuming rank-2 [batch, depth] entries per step); transpose the
        # stacked states and outputs to batch-major [batch, time, depth].
        flat_states = nest.flatten(state_ta)
        flat_states = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_states
        ]
        states = nest.pack_sequence_as(structure=state_ta,
                                       flat_sequence=flat_states)

        flat_outputs = nest.flatten(emit_ta)
        flat_outputs = [
            array_ops.transpose(ta.stack(), (1, 0, 2)) for ta in flat_outputs
        ]
        outputs = nest.pack_sequence_as(structure=emit_ta,
                                        flat_sequence=flat_outputs)

        return (states, outputs, final_state)
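

# Usage sketch (not part of the original source): a minimal, hedged example of
# driving the raw_rnn variant above with an LSTM cell in TF 1.x graph mode.
# The names below (max_time, example_batch, input_depth, num_units, inputs,
# sequence_length) are hypothetical example values; the public tf.* API is
# assumed alongside the internal imports used by raw_rnn itself.
import tensorflow as tf

max_time, example_batch, input_depth, num_units = 10, 4, 8, 16

# Time-major inputs [max_time, batch, depth], unstacked into a TensorArray so
# loop_fn can read one time step at a time.
inputs = tf.placeholder(tf.float32, shape=(max_time, example_batch, input_depth))
sequence_length = tf.constant([10, 7, 5, 10], dtype=tf.int32)
inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time).unstack(inputs)

cell = tf.nn.rnn_cell.LSTMCell(num_units)


def loop_fn(time, cell_output, cell_state, loop_state):
    """Same contract as tf.nn.raw_rnn's loop_fn."""
    if cell_output is None:
        # time == 0: provide the initial state; emit_output=None means the
        # emit structure defaults to cell.output_size.
        next_cell_state = cell.zero_state(example_batch, tf.float32)
        emit_output = None
    else:
        next_cell_state = cell_state
        emit_output = cell_output
    elements_finished = (time >= sequence_length)
    all_finished = tf.reduce_all(elements_finished)
    next_input = tf.cond(
        all_finished,
        lambda: tf.zeros([example_batch, input_depth], dtype=tf.float32),
        lambda: inputs_ta.read(time))
    return (elements_finished, next_input, next_cell_state, emit_output, None)


# states and outputs come back batch-major: [batch, time, depth].
states, outputs, final_state = raw_rnn(cell, loop_fn)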