Example #1
def recurrent_net(
        net,
        cell_net,
        inputs,
        initial_cell_inputs,
        links,
        timestep=None,
        scope=None,
        outputs_with_grads=(0, ),
        recompute_blobs_on_backward=None,
):
    '''
    net: the main net the operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk) where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names at timestep t+1 to
    output names at timestep t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: Internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for the backward pass, and thus need not be
                 stored for each forward timestep.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references:
    # those that are not referred to in inputs or initial_cell_inputs
    known_inputs = map(str, input_blobs + initial_input_blobs)
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs
    ]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]

    if recompute_blobs_on_backward is not None:
        # Insert operators to re-compute the specified blobs.
        # They are added in the same order as for the forward pass, thus
        # the order is correct.
        recompute_blobs_on_backward = set(
            [str(b) for b in recompute_blobs_on_backward])
        for op in cell_net.Proto().op:
            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                backward_cell_net.Proto().op.extend([op])
                assert set(op.output).issubset(recompute_blobs_on_backward), \
                       'Outputs {} are produced by the op but not recomputed: {}'.format(
                            set(op.output) - recompute_blobs_on_backward,
                            op
                       )
    else:
        recompute_blobs_on_backward = set()

    backward_cell_net.Proto().op.extend(backward_ops)
    # compute blobs used but not defined in the backward pass
    backward_ssa, backward_blob_versions = core.get_ssa(
        backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(backward_ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items() if ver > 0
        and blob in undefined and blob not in cell_net.Proto().external_output
    ]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [x[1] for x in initial_cell_inputs
                                           ] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At timestep t, the corresponding data block is at position
    # t + offset in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # States holding the recurrent inputs to the cell net
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        recurrent_input_grad = cell_input + "_grad"
        if not backward_blob_versions.get(recurrent_input_grad, 0):
            # If nobody writes to this recurrent input gradient, we need
            # to make sure it still ends up in the states grad blob.
            # We do this via backward_links, which triggers an alias.
            # This logic is used, for example, in the SumOp case.
            backward_links.append(
                (backward_mapping[cell_input], states_grad, 0))
        else:
            backward_links.append((cell_input + "_grad", states_grad, 0))

    for reference in references:
        # Similar to the above: in the case of a SumOp we need to write our
        # parameter gradient to an external blob. Here we can be sure that
        # reference + "_grad" is a correct parameter name, as we know what
        # the RecurrentNetworkOp gradient schema looks like.
        reference_grad = reference + "_grad"
        if (reference in backward_mapping
                and reference_grad != str(backward_mapping[reference])):
            # We can use an Alias because after each timestep the RNN op
            # adds the value from reference_grad into an _acc blob, which
            # accumulates gradients for the corresponding parameter across
            # timesteps. At the end of the RNN op these two are swapped,
            # and the reference_grad blob becomes a real blob instead of
            # an alias.
            backward_cell_net.Alias(backward_mapping[reference],
                                    reference_grad)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append(
            (backward_mapping[str(input_t)], str(input_blob) + "_grad", 0))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split into separate lists so we can pass them to C++,
    # where they are assembled back together
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping.keys()]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=map(all_inputs.index, params),
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=map(str, recompute_blobs_on_backward))
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
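The sketch below (not part of the original example) shows roughly how this version might be wired up: a trivial cell that just sums the previous hidden state with the current input slice. All blob names ("seq", "hidden_init", "hidden_t_prev", "input_t", "hidden_t") are hypothetical, and it assumes the default (empty) name scope so the external-input scoping assertion passes; this era of Caffe2 code targets Python 2.

from caffe2.python import core

# Build a minimal cell net: hidden_t = hidden_t_prev + input_t
cell_net = core.Net("sum_cell")
hidden_t_prev, input_t = cell_net.AddExternalInputs("hidden_t_prev", "input_t")
hidden_t = cell_net.Sum([hidden_t_prev, input_t], "hidden_t")
cell_net.AddExternalOutputs(hidden_t)

# Unroll it over a T x N x D sequence blob "seq", starting from "hidden_init"
main_net = core.Net("main")
hidden_all, hidden_last = recurrent_net(
    net=main_net,
    cell_net=cell_net,
    inputs=[("input_t", "seq")],
    initial_cell_inputs=[("hidden_t_prev", "hidden_init")],
    links={"hidden_t_prev": "hidden_t"},
    outputs_with_grads=(0,),
)
# hidden_all aliases the full T x N x D state history,
# hidden_last the state after the final timestep.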
Example #2
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net the operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk) where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names at timestep t+1 to
    output names at timestep t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: Internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for the backward pass, and thus need not be
                 stored for each forward timestep.

    forward_only: if True, only the forward step net is built; no backward
    pass is constructed
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references:
    # those that are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as for the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if the recomputed ops produce outputs other
                    # than the ones declared for recomputation
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # compute blobs used but not defined in the backward pass
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # also add to the output list the intermediate outputs of fwd_step that
        # are used by backward.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'rnn'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At timestep t, the corresponding data block is at position
    # t + offset in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # States holding the recurrent inputs to the cell net
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nobody writes to this recurrent input gradient, we need
                # to make sure it still ends up in the states grad blob.
                # We do this via backward_links, which triggers an alias.
                # This logic is used, for example, in the SumOp case.
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((cell_input + "_grad", states_grad, 0))

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split into separate lists so we can pass them to C++,
    # where they are assembled back together
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives gradient from both an
    # external connection as well as from within the backward_cell_net,
    # those gradients need to be added together, rather than one overwriting
    # the other)
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # An in-place operation won't cause issues because it
                    # takes the existing value of the blob into account
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    backward_args = {}
    if backward_cell_net is not None:
        # backward_mapping only exists when the backward pass was built above,
        # so compute its key set here rather than unconditionally
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'backward_step_net': str(backward_cell_net.Proto()),
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        step_net=str(cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore net type since 'rnn' is not recognized outside RNNs
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
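Example #2 adds a forward_only flag for inference-only unrolling. A minimal sketch of using it, with the same trivial sum cell and hypothetical blob names as in the sketch after Example #1, assuming the default (empty) name scope:

from caffe2.python import core

# Same toy cell: h = h_prev + x_t
cell_net = core.Net("sum_cell")
h_prev, x_t = cell_net.AddExternalInputs("h_prev", "x_t")
h = cell_net.Sum([h_prev, x_t], "h")
cell_net.AddExternalOutputs(h)

main_net = core.Net("main")
h_all, h_last = recurrent_net(
    net=main_net,
    cell_net=cell_net,
    inputs=[("x_t", "seq")],
    initial_cell_inputs=[("h_prev", "h_init")],
    links={"h_prev": "h"},
    forward_only=True,  # skip building the backward step net entirely
)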
Example #3
def recurrent_net(net,
                  cell_net,
                  inputs,
                  initial_cell_inputs,
                  links,
                  timestep=None,
                  scope=None,
                  outputs_with_grads=(0, )):
    '''
    net: the main net the operator should be added to

    cell_net: cell_net which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk) where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary from cell_net input names at timestep t+1 to
    output names at timestep t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: Internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>
    If not provided, a scope name is generated automatically

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references:
    # those that are not referred to in inputs or initial_cell_inputs
    known_inputs = map(str, input_blobs + initial_input_blobs)
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs
    ]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): str(v) for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")

    del backward_cell_net.Proto().op[:]
    backward_cell_net.Proto().op.extend(backward_ops)

    # compute blobs used but not defined in the backward pass
    ssa, _ = core.get_ssa(backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items() if ver > 0
        and blob in undefined and blob not in cell_net.Proto().external_output
    ]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [x[1] for x in initial_cell_inputs
                                           ] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At timestep t, the corresponding data block is at position
    # t + offset in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # States holding the recurrent inputs to the cell net
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_input + "_grad", states_grad, 0))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append(
            (backward_mapping[str(input_t)], str(input_blob) + "_grad", 0))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Split into separate lists so we can pass them to C++,
    # where they are assembled back together
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping.keys()]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    global _workspace_seq
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces_{}".format(_workspace_seq))],
        param=map(all_inputs.index, params),
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
    )
    _workspace_seq += 1
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
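For reference, the link and alias triples that all three versions assemble follow the (internal_blob, external_blob, offset) convention described in the comments above. Below is a small standalone illustration with hypothetical blob names (the scope prefix on the state blob is omitted); it only demonstrates the offset semantics and the unpack_triple helper, not the operator itself.

# Each forward link says: at timestep t, the cell-net blob on the left is a
# view into the external blob on the right at position t + offset.
forward_links = [
    ("hidden_t_prev", "hidden_t_prev_states", 0),  # read state written at t
    ("hidden_t",      "hidden_t_prev_states", 1),  # write state for t + 1
    ("input_t",       "seq",                  0),  # slice of the T x N x D input
]

def unpack_triple(x):
    # Same helper as above: split triples into the parallel lists that the
    # RecurrentNetwork operator expects as separate arguments.
    if x:
        a, b, c = zip(*x)
        return a, b, c
    return [], [], []

link_internal, link_external, link_offset = unpack_triple(forward_links)
assert link_internal == ("hidden_t_prev", "hidden_t", "input_t")
assert link_external == ("hidden_t_prev_states", "hidden_t_prev_states", "seq")
assert link_offset == (0, 1, 0)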