def wrapper_fn(*args):
    """Re-nest flat positional arguments and call ``inner_fn`` with them.

    The flat ``args`` are split into three consecutive groups: the first
    ``n_seq`` items, the next ``n_rec`` items, and everything after that.
    Each group is packed into the structure of ``nest_sequences``,
    ``nest_rec_info`` and ``nest_non_sequences`` respectively. The
    top-level entries of the three packed groups are concatenated and
    passed to ``inner_fn``; its result is returned flattened.
    """
    rec_end = n_seq + n_rec
    flat_groups = (args[:n_seq], args[n_seq:rec_end], args[rec_end:])
    templates = (nest_sequences, nest_rec_info, nest_non_sequences)

    # Pack every group first, then splice their top-level entries together.
    packed_groups = [
        utils.pack_sequence_as(template, flat)
        for template, flat in zip(templates, flat_groups)
    ]
    call_args = [entry for group in packed_groups for entry in group]

    return utils.flatten(inner_fn(*call_args))
Beispiel #2
0
def select_nbest(nested, indices):
    """Index every leaf of ``nested`` with ``indices``.

    If ``nested`` is not a list or tuple it is indexed directly and the
    result returned. Otherwise each item produced by ``flatten(nested)`` is
    indexed with ``indices`` and the selections are packed back into the
    structure of ``nested`` via ``pack_sequence_as``.
    """
    if isinstance(nested, (list, tuple)):
        picked = [leaf[indices] for leaf in flatten(nested)]
        return pack_sequence_as(nested, picked)

    return nested[indices]
Beispiel #3
0
    def wrapper(*inputs):
        """Call ``fn`` on the flattened ``inputs`` and shape its result.

        When ``post_proc`` is truthy only the first element of the outputs
        of ``fn`` is returned; otherwise the outputs are packed into the
        structure of ``nest_outputs``.
        """
        flat_inputs = utils.flatten(inputs)
        results = fn(*flat_inputs)

        if not post_proc:
            return utils.pack_sequence_as(nest_outputs, results)

        return results[0]
def scan(fn, sequences=None, outputs_info=None, non_sequences=None,
         return_updates=False, use_extension=False, **kwargs):
    """Call ``theano.scan``, optionally accepting nested arguments.

    With ``use_extension`` enabled, ``sequences``, ``outputs_info`` and
    ``non_sequences`` may be nested structures. They are flattened with
    ``utils.flatten`` before being handed to ``theano.scan``, ``fn`` is
    wrapped so that it receives its arguments re-packed into the original
    nesting, and the scan outputs are packed back into the structure of
    ``outputs_info``.

    Args:
        fn: step function forwarded to ``theano.scan``. With
            ``use_extension`` its return value is flattened with
            ``utils.flatten`` before being handed back to ``theano.scan``.
        sequences: forwarded to ``theano.scan``; ``None`` becomes ``[]``.
        outputs_info: forwarded to ``theano.scan``; ``None`` becomes ``[]``.
            With ``use_extension`` a dict is rejected and ``None`` may only
            appear as a top-level entry.
        non_sequences: forwarded to ``theano.scan``; ``None`` becomes ``[]``.
        return_updates: when true, return ``(outputs, updates)`` and do not
            register the updates in a collection.
        use_extension: enable the nested-structure handling described above.
        **kwargs: forwarded to ``theano.scan``, except for ``key``, which is
            consumed here. ``key`` selects the collection suffix used to
            store the updates; a missing or falsy ``key`` means
            ``"training"``.

    Returns:
        ``outputs`` alone, or ``(outputs, updates)`` when ``return_updates``
        is true.

    Raises:
        ValueError: with ``use_extension``, if ``outputs_info`` is a dict or
            contains ``None`` below its top level.
    """
    if sequences is None:
        sequences = []

    if outputs_info is None:
        outputs_info = []

    if non_sequences is None:
        non_sequences = []

    # "key" is an option of this wrapper, not of theano.scan. It has to be
    # removed before kwargs are forwarded, otherwise theano.scan rejects the
    # unexpected keyword argument.
    key = kwargs.pop("key", None) or "training"

    # support nested structure for sequences, outputs_info and non_sequences
    if use_extension:
        if isinstance(outputs_info, dict):
            raise ValueError("only support nested structure, not dict")

        nest_sequences = sequences
        nest_outputs_info = outputs_info
        nest_non_sequences = non_sequences

        # inputs to Theano's scan
        sequences = utils.flatten(nest_sequences)
        outputs_info = utils.flatten(nest_outputs_info)
        non_sequences = utils.flatten(nest_non_sequences)

        # input structure for fn: top-level None entries of outputs_info
        # are dropped, so only the remaining entries are re-packed for fn
        nest_rec_info = [item for item in nest_outputs_info
                         if item is not None]
        rec_info = utils.flatten(nest_rec_info)

        # Any None still present after dropping the top-level ones must
        # have come from a deeper nesting level.
        if any(item is None for item in rec_info):
            raise ValueError("None can only appear in the outer level of "
                             "outputs_info")

        n_seq = len(sequences)
        n_rec = len(rec_info)
        inner_fn = fn

        def wrapper_fn(*args):
            # The slicing assumes the flat arguments arrive as sequences,
            # then recurrent outputs, then non-sequences.
            seqs = args[:n_seq]
            recs = args[n_seq: n_seq + n_rec]
            nonseq = args[n_seq + n_rec:]
            nest_seqs = utils.pack_sequence_as(nest_sequences, seqs)
            nest_recs = utils.pack_sequence_as(nest_rec_info, recs)
            nest_nonseq = utils.pack_sequence_as(nest_non_sequences, nonseq)
            newargs = list(nest_seqs) + list(nest_recs) + list(nest_nonseq)

            nest_outs = inner_fn(*newargs)

            return utils.flatten(nest_outs)

        fn = wrapper_fn

    outputs, updates = theano.scan(fn, sequences, outputs_info, non_sequences,
                                   **kwargs)

    if use_extension:
        outputs = utils.pack_sequence_as(nest_outputs_info, outputs)

    if return_updates:
        return outputs, updates

    add_to_collection(_SCAN_UPDATES_KEYS + "/" + key, updates)

    return outputs