Example #1
def get_external_blob_names(net, lexical_scope):
    """
    Returns a set of blobs a given net depends on and a set of
    output blobs that are written by the net
    Inputs:
        net - net to return input/output blobs for;
        lexical_scope - all external blob names visible to the net
    """
    net_proto = net.Proto()
    net_ssa, _ = core.get_ssa(net_proto)
    input_names = core.get_undefined_blobs(net_ssa)
    if net_proto.external_input:
        input_names |= set(net_proto.external_input)

    output_names = set()
    if net_proto.external_output:
        output_names = set(net_proto.external_output)
    for op in net_proto.op:
        for output in op.output:
            if output in lexical_scope:
                output_names.add(output)

    return input_names, output_names
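
A minimal usage sketch of the helper above, assuming caffe2 is importable; the blob names and the single FC op are made up for illustration:

from caffe2.python import core

lexical_scope = {"x", "w", "b", "y"}
net = core.Net("example_net")
net.FC(["x", "w", "b"], "y")  # reads x, w, b and writes y

input_names, output_names = get_external_blob_names(net, lexical_scope)
# input_names is expected to contain the blobs the net reads but never
# produces ("x", "w", "b"); output_names the written blobs visible in
# lexical_scope ("y"). The exact element types mirror whatever
# core.get_undefined_blobs returns, so convert with str() before comparing.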
Example #2
def get_external_blob_names(net, lexical_scope):
    """
    Returns a set of blobs a given net depends on and a set of
    output blobs that are written by the net
    Inputs:
        net - net to return input/output blobs for;
        lexical_scope - all external blob names visible to the net
    """
    # Use the blobs that are actually read/written to as external inputs/outputs
    net_proto = net.Proto()
    net_ssa, _ = core.get_ssa(net_proto)
    input_names = core.get_undefined_blobs(net_ssa)
    for input_name in input_names:
        assert str(input_name) in lexical_scope, \
            "Input blob " + str(input_name) + " is undefined"

    output_names = set()
    for op in net_proto.op:
        for output in op.output:
            if output in lexical_scope:
                output_names.add(output)

    return input_names, output_names
Example #4
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for the backward pass, and thus need not be
                 stored for each forward timestep.

    forward_only: if True, only forward steps are executed
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as for the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if the recomputed ops also compute outputs
                    # other than the ones declared for recomputation
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # compute blobs used but not defined in the backward pass
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # also add to the output list the intermediate outputs of fwd_step that
        # are used by backward.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'rnn'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nobody writes to this recurrent input gradient, we need
                # to make sure it gets to the states grad blob after all.
                # We do this by using backward_links which triggers an alias
                # This logic is being used for example in a SumOp case
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((cell_input + "_grad", states_grad, 0))

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives gradient from both an
    # external connection as well as from within the backward_cell_net,
    # those gradients need to be added together, rather than one overwriting
    # the other)
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # In place operation won't cause issues because it takes
                    # existing value of a blob into account
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'backward_step_net': str(backward_cell_net.Proto()),
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        step_net=str(cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore net type since 'rnn' is not recognized outside RNNs
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
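
A self-contained sketch of how the link triples built above are flattened before being handed to the C++ operator; unpack_triple matches the nested helper defined inside recurrent_net, and the blob names are made up:

forward_links = [
    ("hidden_t_prev", "hidden_states", 0),  # read state at position t
    ("hidden_t", "hidden_states", 1),       # write state at position t + 1
]

def unpack_triple(x):
    if x:
        a, b, c = zip(*x)
        return a, b, c
    return [], [], []

link_internal, link_external, link_offset = unpack_triple(forward_links)
# link_internal -> ('hidden_t_prev', 'hidden_t')
# link_external -> ('hidden_states', 'hidden_states')
# link_offset   -> (0, 1)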
Example #5
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, scratch_sizes,
        timestep=None, scope=None
):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    scratch_sizes: sizes of the scratch blobs. Scratch blobs are those
    intermediate blobs of the cell_net which are used in the backward pass.
    We use the size information to preallocate memory for them over time.

    For example, in the case of LSTM we have an FC -> Sum -> LSTMUnit sequence
    of operations in each iteration of the cell net. The output of Sum is an
    intermediate blob that is also part of the backward pass, so it is a
    scratch blob whose size we must provide.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        b for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): str(v) for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")

    del backward_cell_net.Proto().op[:]
    backward_cell_net.Proto().op.extend(backward_ops)

    # compute blobs used but not defined in the backward pass
    ssa, _ = core.get_ssa(backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items()
        if ver > 0 and
        blob in undefined and
        blob not in cell_net.Proto().external_output]

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []
    backward_aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    # a map from gradient blob name to the blob with its value over time
    grad_to_state = {}

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_input + "_grad", states_grad, 0))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_last", -1))
        aliases.append((state, cell_output + "_all", 1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

    for scratch in scratches:
        # no scoping as scratches should be already scoped
        forward_links.append((scratch, scratch + "_states", 0))
        grad_blob = scratch + "_grad"
        states_grad_blob = scratch + "_states_grad"
        backward_links.append((grad_blob, states_grad_blob, 0))
        backward_cell_net.Proto().external_input.append(scratch)
        grad_to_state[grad_blob] = states_grad_blob

    input_gradient_ids = []
    for input_id, (input_t, input_blob) in enumerate(inputs):
        forward_links.append((str(input_t), str(input_blob), 0))

        input_blob_grad = str(input_blob) + "_grad"
        if backward_mapping[str(input_t)] != str(input_t) + "_grad":
            # Some scratch (internal blob) ends up being an input gradient
            # So we avoid extra copy and reuse it by applying this alias
            backward_aliases.append((
                grad_to_state[backward_mapping[str(input_t)]],
                input_blob_grad,
                0
            ))
        else:
            # This is a general case - we have to explicitly create input
            # gradient blob as it doesn't match any of internal gradients
            backward_links.append(
                (str(input_t) + "_grad", input_blob_grad, 0))
            input_gradient_ids.append(input_id)

    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)
    backward_alias_src, backward_alias_dst, backward_alias_offset = \
        unpack_triple(backward_aliases)

    params = [x for x in references if x in backward_mapping.keys()]

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    return net.RecurrentNetwork(
        all_inputs,
        all_outputs,
        param=map(all_inputs.index, params),
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        recurrent_inputs=recurrent_inputs,
        recurrent_input_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        backward_alias_src=backward_alias_src,
        backward_alias_dst=backward_alias_dst,
        backward_alias_offset=backward_alias_offset,
        scratch=[sc + "_states" for sc in scratches],
        backward_scratch=[sc + "_states_grad" for sc in scratches],
        scratch_sizes=scratch_sizes,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        input_gradient_ids=input_gradient_ids,
    )
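
A sketch of the input-gradient reuse above: when GetBackwardPass maps an input to a gradient blob that is already a scratch gradient, the code aliases that scratch's state-gradient blob instead of adding another link. All names below are made up:

backward_mapping = {"input_t": "fc_out_grad"}            # not "input_t_grad"
grad_to_state = {"fc_out_grad": "fc_out_states_grad"}    # scratch grad -> its state blob

input_t, input_blob = "input_t", "seq_data"
if backward_mapping[input_t] != input_t + "_grad":
    backward_alias = (grad_to_state[backward_mapping[input_t]],
                      input_blob + "_grad", 0)
# backward_alias -> ("fc_out_states_grad", "seq_data_grad", 0)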
Example #6
def recurrent_net(
        net,
        cell_net,
        inputs,
        initial_cell_inputs,
        links,
        timestep=None,
        scope=None,
        outputs_with_grads=(0, ),
        recompute_blobs_on_backward=None,
):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for the backward pass, and thus need not be
                 stored for each forward timestep.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs
    ]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): v for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")
    del backward_cell_net.Proto().op[:]

    if recompute_blobs_on_backward is not None:
        # Insert operators to re-compute the specified blobs.
        # They are added in the same order as for the forward pass, thus
        # the order is correct.
        recompute_blobs_on_backward = set(
            [str(b) for b in recompute_blobs_on_backward])
        for op in cell_net.Proto().op:
            if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                backward_cell_net.Proto().op.extend([op])
                assert set(op.output).issubset(recompute_blobs_on_backward), \
                       'Outputs {} are output by op but not recomputed: {}'.format(
                            set(op.output) - recompute_blobs_on_backward,
                            op
                       )
    else:
        recompute_blobs_on_backward = set()

    backward_cell_net.Proto().op.extend(backward_ops)
    # compute blobs used but not defined in the backward pass
    backward_ssa, backward_blob_versions = core.get_ssa(
        backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(backward_ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items() if ver > 0
        and blob in undefined and blob not in cell_net.Proto().external_output
    ]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [x[1] for x in initial_cell_inputs
                                           ] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        recurrent_input_grad = cell_input + "_grad"
        if not backward_blob_versions.get(recurrent_input_grad, 0):
            # If nobody writes to this recurrent input gradient, we need
            # to make sure it gets to the states grad blob after all.
            # We do this by using backward_links which triggers an alias
            # This logic is being used for example in a SumOp case
            backward_links.append(
                (backward_mapping[cell_input], states_grad, 0))
        else:
            backward_links.append((cell_input + "_grad", states_grad, 0))

    for reference in references:
        # Similar to above, in a case of a SumOp we need to write our parameter
        # gradient to an external blob. In this case we can be sure that
        # reference + "_grad" is a correct parameter name as we know how
        # RecurrentNetworkOp gradient schema looks like.
        reference_grad = reference + "_grad"
        if (reference in backward_mapping
                and reference_grad != str(backward_mapping[reference])):
            # We can use an Alias because after each timestep
            # the RNN op adds the value from reference_grad into an _acc blob,
            # which accumulates gradients for the corresponding parameter across
            # timesteps. Then at the end of the RNN op these two are
            # swapped and the reference_grad blob becomes a real blob instead of
            # being an alias
            backward_cell_net.Alias(backward_mapping[reference],
                                    reference_grad)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append(
            (backward_mapping[str(input_t)], str(input_blob) + "_grad", 0))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping.keys()]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        param=map(all_inputs.index, params),
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
        recompute_blobs_on_backward=map(str, recompute_blobs_on_backward))
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
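
A small sketch of how the alias offsets above can be read, under the stated convention that a negative offset counts from the end and a positive one from the beginning; the shapes are made up and numpy stands in for the state tensor:

import numpy as np

T, N, D = 4, 2, 3
hidden_states = np.zeros((T + 1, N, D))  # row 0 holds the initial state

# alias (state, cell_output + "_all", 1): every timestep after the initial row
hidden_all = hidden_states[1:]
# alias (state, cell_output + "_last", -1): just the final timestep
hidden_last = hidden_states[-1:]

assert hidden_all.shape == (T, N, D)
assert hidden_last.shape == (1, N, D)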
Example #7
def recurrent_net(net,
                  cell_net,
                  inputs,
                  initial_cell_inputs,
                  links,
                  timestep=None,
                  scope=None,
                  outputs_with_grads=(0, )):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    # Validate scoping
    for einp in cell_net.Proto().external_input:
        assert einp.startswith(CurrentNameScope()), \
            '''
            Cell net external inputs are not properly scoped, use
            AddScopedExternalInputs() when creating them
            '''

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs
    ]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): str(v) for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")

    del backward_cell_net.Proto().op[:]
    backward_cell_net.Proto().op.extend(backward_ops)

    # compute blobs used but not defined in the backward pass
    ssa, _ = core.get_ssa(backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items() if ver > 0
        and blob in undefined and blob not in cell_net.Proto().external_output
    ]
    backward_cell_net.Proto().external_input.extend(scratches)

    all_inputs = [i[1] for i in inputs] + [x[1] for x in initial_cell_inputs
                                           ] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_input + "_grad", states_grad, 0))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))
        backward_links.append(
            (backward_mapping[str(input_t)], str(input_blob) + "_grad", 0))
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    params = [x for x in references if x in backward_mapping.keys()]
    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    global _workspace_seq
    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces_{}".format(_workspace_seq))],
        param=map(all_inputs.index, params),
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=map(all_inputs.index, recurrent_inputs),
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        outputs_with_grads=outputs_with_grads,
    )
    _workspace_seq += 1
    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
Example #8
def recurrent_net(
        net, cell_net, inputs, initial_cell_inputs,
        links, timestep=None, scope=None, outputs_with_grads=(0,),
        recompute_blobs_on_backward=None, forward_only=False,
):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.

    outputs_with_grads: position indices of output blobs which will receive
    error gradient (from outside the recurrent network) during backpropagation

    recompute_blobs_on_backward: specify a list of blobs that will be
                 recomputed for the backward pass, and thus need not be
                 stored for each forward timestep.

    forward_only: if True, only forward steps are executed
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        core.BlobReference(b) for b in cell_net.Proto().external_input
        if b not in known_inputs]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    if not forward_only:
        backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
            cell_net.Proto().op, inner_outputs_map)
        backward_mapping = {str(k): v for k, v in viewitems(backward_mapping)}

        backward_cell_net = core.Net("RecurrentBackwardStep")
        del backward_cell_net.Proto().op[:]

        if recompute_blobs_on_backward is not None:
            # Insert operators to re-compute the specified blobs.
            # They are added in the same order as for the forward pass, thus
            # the order is correct.
            recompute_blobs_on_backward = {str(b) for b in
                                           recompute_blobs_on_backward}

            for op in cell_net.Proto().op:
                if not recompute_blobs_on_backward.isdisjoint(set(op.output)):
                    backward_cell_net.Proto().op.extend([op])
                    # This fires if the recomputed ops also compute outputs
                    # other than the ones declared for recomputation
                    assert set(op.output).issubset(recompute_blobs_on_backward)

        backward_cell_net.Proto().op.extend(backward_ops)
        # compute blobs used but not defined in the backward pass
        backward_ssa, backward_blob_versions = core.get_ssa(
            backward_cell_net.Proto())
        undefined = core.get_undefined_blobs(backward_ssa)

        # also add to the output list the intermediate outputs of fwd_step that
        # are used by backward.
        ssa, blob_versions = core.get_ssa(cell_net.Proto())
        scratches = [
            blob
            for blob, ver in viewitems(blob_versions)
            if (ver > 0 and
                blob in undefined and
                blob not in cell_net.Proto().external_output)
        ]
        backward_cell_net.Proto().external_input.extend(scratches)
        backward_cell_net.Proto().type = 'simple'
    else:
        backward_cell_net = None

    all_inputs = [i[1] for i in inputs] + [
        x[1] for x in initial_cell_inputs] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    for cell_input, _ in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))

        aliases.append((state, cell_output + "_all", 1))
        aliases.append((state, cell_output + "_last", -1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

        if backward_cell_net is not None:
            backward_links.append((cell_output + "_grad", states_grad, 1))
            backward_cell_net.Proto().external_input.append(
                str(cell_output) + "_grad")

            recurrent_input_grad = cell_input + "_grad"
            if not backward_blob_versions.get(recurrent_input_grad, 0):
                # If nobody writes to this recurrent input gradient, we need
                # to make sure it gets to the states grad blob after all.
                # We do this by using backward_links which triggers an alias
                # This logic is being used for example in a SumOp case
                backward_links.append(
                    (backward_mapping[cell_input], states_grad, 0))
            else:
                backward_links.append((recurrent_input_grad, states_grad, 0))

    for input_t, input_blob in inputs:
        forward_links.append((str(input_t), str(input_blob), 0))

    if backward_cell_net is not None:
        for input_t, input_blob in inputs:
            backward_links.append((
                backward_mapping[str(input_t)], str(input_blob) + "_grad", 0
            ))
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_input)
        backward_cell_net.Proto().external_input.extend(
            cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)

    recurrent_inputs = [str(x[1]) for x in initial_cell_inputs]

    # Make sure that recurrent gradients accumulate with internal gradients
    # (if a blob in the backward_cell_net receives gradient from both an
    # external connection as well as from within the backward_cell_net,
    # those gradients need to be added together, rather than one overwriting
    # the other)
    if backward_cell_net is not None:
        proto = backward_cell_net.Proto()
        operators = []
        while len(proto.op) > 0:
            op = proto.op[-1]
            proto.op.remove(op)
            operators.append(op)
        for op in operators[::-1]:
            proto.op.extend([op])
            for j, output_blob in enumerate(op.output):
                if output_blob in proto.external_input:
                    # In place operation won't cause issues because it takes
                    # existing value of a blob into account
                    if output_blob in op.input:
                        continue
                    output_blob = core.BlobReference(output_blob)
                    accum_blob = output_blob + "_accum"
                    proto.op[-1].output[j] = str(accum_blob)
                    backward_cell_net.Sum(
                        [output_blob, accum_blob],
                        [output_blob],
                    )

    def map_to_dual_list(m):
        return [str(x) for x in list(m.keys())] + \
               [str(x) for x in list(m.values())]

    backward_args = {}
    if backward_cell_net is not None:
        backward_mapping_keys = set(viewkeys(backward_mapping))
        backward_link_internal, backward_link_external, backward_link_offset = \
            unpack_triple(backward_links)
        params = [x for x in references if x in backward_mapping_keys]
        param_grads = [
            str(backward_mapping[x])
            for x in references
            if x in backward_mapping_keys
        ]
        if recompute_blobs_on_backward is None:
            recompute_blobs_on_backward = set()
        backward_args = {
            'param': [all_inputs.index(p) for p in params],
            'backward_link_internal': [str(l) for l in backward_link_internal],
            'backward_link_external': [str(l) for l in backward_link_external],
            'backward_link_offset': backward_link_offset,
            'outputs_with_grads': outputs_with_grads,
            'recompute_blobs_on_backward': [
                str(b) for b in recompute_blobs_on_backward
            ],
            'param_grads': param_grads,
        }
        if len(backward_cell_net.Proto().op) != 0:
            backward_args['backward_step_net'] = backward_cell_net.Proto()

    results = net.RecurrentNetwork(
        all_inputs,
        all_outputs + [s("step_workspaces")],
        alias_src=alias_src,
        alias_dst=[str(a) for a in alias_dst],
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        initial_recurrent_state_ids=[
            all_inputs.index(i) for i in recurrent_inputs
        ],
        link_internal=[str(l) for l in link_internal],
        link_external=[str(l) for l in link_external],
        link_offset=link_offset,
        enable_rnn_executor=1,
        step_net=cell_net.Proto(),
        timestep="timestep" if timestep is None else str(timestep),
        **backward_args
    )

    # Restore net type since 'rnn' is not recognized outside RNNs
    cell_net.Proto().type = 'simple'

    # The last output is a list of step workspaces,
    # which is only needed internally for gradient propagation
    return results[:-1]
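
The "_accum" rewrite above exists because a blob inside backward_cell_net can receive gradient both through an external link and from an op inside the step net; a tiny numpy sketch of why overwriting would be wrong (the values are made up):

import numpy as np

external_grad = np.array([1.0, 2.0])  # arrives via a backward link
internal_grad = np.array([0.5, 0.5])  # produced by an op inside the step net

# Without the rewrite the second write would clobber the first:
clobbered = internal_grad
# With the rewrite the op writes to "<blob>_accum" and a Sum op combines both:
accumulated = external_grad + internal_grad

assert not np.array_equal(clobbered, accumulated)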
Example #9
def recurrent_net(net,
                  cell_net,
                  inputs,
                  initial_cell_inputs,
                  links,
                  scratch_sizes,
                  timestep=None,
                  scope=None):
    '''
    net: the main net the operator should be added to

    cell_net: the cell net, which is executed in a recurrent fashion

    inputs: sequences to be fed into the recurrent net. Currently only one input
    is supported. It has to be in the format T x N x (D1...Dk), where T is the
    length of the sequence, N is the batch size and (D1...Dk) are the remaining
    dimensions

    initial_cell_inputs: inputs of the cell_net for timestep 0.
    Format for each input is:
        (cell_net_input_name, external_blob_with_data, input_size)

    links: a dictionary mapping cell_net input names at moment t+1 to
    output names at moment t. Currently we assume that each output becomes
    an input for the next timestep.

    scratch_sizes: sizes of the scratch blobs. Scratch blobs are those
    intermediate blobs of the cell_net which are used in the backward pass.
    We use the size information to preallocate memory for them over time.

    For example, in the case of LSTM we have an FC -> Sum -> LSTMUnit sequence
    of operations in each iteration of the cell net. The output of Sum is an
    intermediate blob that is also part of the backward pass, so it is a
    scratch blob whose size we must provide.

    timestep: name of the timestep blob to be used. If not provided, "timestep"
    is used.

    scope: internal blobs are going to be scoped in the format
    <scope_name>/<blob_name>.
    If not provided, a scope name is generated automatically.
    '''
    assert len(inputs) == 1, "Only one input blob is supported so far"

    input_blobs = [str(i[0]) for i in inputs]
    initial_input_blobs = [str(x[1]) for x in initial_cell_inputs]
    op_name = net.NextName('recurrent')

    def s(name):
        # We have to manually scope due to our internal/external blob
        # relationships.
        scope_name = op_name if scope is None else scope
        return "{}/{}".format(str(scope_name), str(name))

    # Determine inputs that are considered to be references, i.e. those that
    # are not referred to in inputs or initial_cell_inputs
    known_inputs = [str(b) for b in input_blobs + initial_input_blobs]
    known_inputs += [str(x[0]) for x in initial_cell_inputs]
    if timestep is not None:
        known_inputs.append(str(timestep))
    references = [
        b for b in cell_net.Proto().external_input if b not in known_inputs
    ]

    inner_outputs = list(cell_net.Proto().external_output)
    # These gradients are expected to be available during the backward pass
    inner_outputs_map = {o: o + '_grad' for o in inner_outputs}

    # compute the backward pass of the cell net
    backward_ops, backward_mapping = core.GradientRegistry.GetBackwardPass(
        cell_net.Proto().op, inner_outputs_map)
    backward_mapping = {str(k): str(v) for k, v in backward_mapping.items()}
    backward_cell_net = core.Net("RecurrentBackwardStep")

    del backward_cell_net.Proto().op[:]
    backward_cell_net.Proto().op.extend(backward_ops)

    # compute blobs used but not defined in the backward pass
    ssa, _ = core.get_ssa(backward_cell_net.Proto())
    undefined = core.get_undefined_blobs(ssa)

    # also add to the output list the intermediate outputs of fwd_step that
    # are used by backward.
    ssa, blob_versions = core.get_ssa(cell_net.Proto())
    scratches = [
        blob for (blob, ver) in blob_versions.items() if ver > 0
        and blob in undefined and blob not in cell_net.Proto().external_output
    ]

    all_inputs = [i[1] for i in inputs] + [x[1] for x in initial_cell_inputs
                                           ] + references
    all_outputs = []

    cell_net.Proto().type = 'simple'
    backward_cell_net.Proto().type = 'simple'

    # Internal arguments used by RecurrentNetwork operator

    # Links are in the format (blob_name, recurrent_states, offset).
    # At moment t the corresponding data block is at position t + offset
    # in the recurrent_states tensor
    forward_links = []
    backward_links = []

    # Aliases are used to expose outputs to the external world
    # Format: (internal_blob, external_blob, offset)
    # A negative offset counts from the end,
    # a positive one from the beginning
    aliases = []
    backward_aliases = []

    # State blobs that hold the inputs to the cell net over time
    recurrent_states = []

    # a map from gradient blob name to the blob with its value over time
    grad_to_state = {}

    for cell_input, _, size in initial_cell_inputs:
        cell_input = str(cell_input)
        # Recurrent_states is going to be (T + 1) x ...
        # It stores all inputs and outputs of the cell net over time.
        # Or their gradients in the case of the backward pass.
        state = s(cell_input + "_states")
        states_grad = state + "_grad"
        cell_output = links[str(cell_input)]
        forward_links.append((cell_input, state, 0))
        forward_links.append((cell_output, state, 1))
        backward_links.append((cell_input + "_grad", states_grad, 0))
        backward_links.append((cell_output + "_grad", states_grad, 1))

        backward_cell_net.Proto().external_input.append(
            str(cell_output) + "_grad")
        aliases.append((state, cell_output + "_last", -1))
        aliases.append((state, cell_output + "_all", 1))
        all_outputs.extend([cell_output + "_all", cell_output + "_last"])

        recurrent_states.append(state)

    for scratch in scratches:
        # no scoping as scratches should be already scoped
        forward_links.append((scratch, scratch + "_states", 0))
        grad_blob = scratch + "_grad"
        states_grad_blob = scratch + "_states_grad"
        backward_links.append((grad_blob, states_grad_blob, 0))
        backward_cell_net.Proto().external_input.append(scratch)
        grad_to_state[grad_blob] = states_grad_blob

    input_gradient_ids = []
    for input_id, (input_t, input_blob) in enumerate(inputs):
        forward_links.append((str(input_t), str(input_blob), 0))

        input_blob_grad = str(input_blob) + "_grad"
        if backward_mapping[str(input_t)] != str(input_t) + "_grad":
            # Some scratch (internal blob) ends up being an input gradient
            # So we avoid extra copy and reuse it by applying this alias
            backward_aliases.append(
                (grad_to_state[backward_mapping[str(input_t)]],
                 input_blob_grad, 0))
        else:
            # This is a general case - we have to explicitly create input
            # gradient blob as it doesn't match any of internal gradients
            backward_links.append((str(input_t) + "_grad", input_blob_grad, 0))
            input_gradient_ids.append(input_id)

    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_input)
    backward_cell_net.Proto().external_input.extend(
        cell_net.Proto().external_output)

    def unpack_triple(x):
        if x:
            a, b, c = zip(*x)
            return a, b, c
        return [], [], []

    # Splitting into separate lists so we can pass them to C++
    # where we assemble them back
    link_internal, link_external, link_offset = unpack_triple(forward_links)
    backward_link_internal, backward_link_external, backward_link_offset = \
        unpack_triple(backward_links)
    alias_src, alias_dst, alias_offset = unpack_triple(aliases)
    backward_alias_src, backward_alias_dst, backward_alias_offset = \
        unpack_triple(backward_aliases)

    params = [x for x in references if x in backward_mapping.keys()]

    return net.RecurrentNetwork(
        all_inputs,
        all_outputs,
        param=params,
        param_gradient=[backward_mapping[p] for p in params],
        alias_src=alias_src,
        alias_dst=map(str, alias_dst),
        alias_offset=alias_offset,
        recurrent_states=recurrent_states,
        recurrent_inputs=[str(x[1]) for x in initial_cell_inputs],
        recurrent_sizes=[int(x[2]) for x in initial_cell_inputs],
        link_internal=map(str, link_internal),
        link_external=map(str, link_external),
        link_offset=link_offset,
        backward_link_internal=map(str, backward_link_internal),
        backward_link_external=map(str, backward_link_external),
        backward_link_offset=backward_link_offset,
        backward_alias_src=backward_alias_src,
        backward_alias_dst=backward_alias_dst,
        backward_alias_offset=backward_alias_offset,
        scratch=[sc + "_states" for sc in scratches],
        backward_scratch=[sc + "_states_grad" for sc in scratches],
        scratch_sizes=scratch_sizes,
        step_net=str(cell_net.Proto()),
        backward_step_net=str(backward_cell_net.Proto()),
        timestep="timestep" if timestep is None else str(timestep),
        input_gradient_ids=input_gradient_ids,
    )
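
Illustrative argument shapes for this older scratch_sizes variant, following the docstring above; every name and number here is made up:

inputs = [("input_t", "seq_data")]             # one T x N x D sequence blob
initial_cell_inputs = [
    ("hidden_t_prev", "hidden_init", 40),      # (cell input, data blob, size)
]
links = {"hidden_t_prev": "hidden_t"}          # input at t+1 comes from output at t
scratch_sizes = [40]                           # one size per scratch blob
timestep = None                                # fall back to the "timestep" blob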