Example #1
    @classmethod
    def _create_gru(cls, init_model, pred_model, n, opset_version):
        assert init_model is not None, "cannot convert GRUs without access to the full model"
        assert pred_model is not None, "cannot convert GRUs without access to the full model"

        attrs = dict(n.attrs)  # make a copy, which is safe to mutate
        hidden_size = attrs.pop('hidden_size')
        linear_before_reset = attrs.pop('linear_before_reset')
        assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())

        input_blob, W, R, B, sequence_lens, initial_h = n.inputs

        if sequence_lens == "":
            sequence_lens = None

        input_size = cls._rnn_shape_inference(init_model, pred_model, n,
                                              input_blob, W)
        if input_size is None:
            raise RuntimeError(
                "best-effort shape inference for GRU input failed")

        name = dummy_name()
        init_net = core.Net("init-net")
        pred_mh = ModelHelper()

        hidden_t_all, hidden_t_last = gru_cell.GRU(
            pred_mh,
            input_blob,
            sequence_lens, [initial_h],
            input_size,
            hidden_size,
            name,
            drop_states=True,
            forward_only=True,
            linear_before_reset=linear_before_reset)

        # input and recurrence biases are squashed together in onnx but not in caffe2
        Bi = name + "_bias_i2h"
        Br = name + "_bias_gates"
        init_net.Slice(B, Bi, starts=[0 * hidden_size], ends=[3 * hidden_size])
        init_net.Slice(B, Br, starts=[3 * hidden_size], ends=[6 * hidden_size])

        # caffe2 has a different order from onnx. We need to rearrange
        #  z r h  -> r z h
        #
        # TODO implement support for return_params in gru_cell.GRU.
        # Until then, hardcode blob names.
        reforms = ((W,  'i2h_w',    True,  [(0, input_size)]),
                   (R,  'gate_t_w', False, [(0, hidden_size)]),
                   (Bi, 'i2h_b',    True,  []),
                   (Br, 'gate_t_b', False, []))
        for name_from, name_to, do_concat, extra_dims in reforms:
            xz, xr, xh = [
                '%s/%s_%s' % (name, prefix, name_to)
                for prefix in ('update', 'reset', 'output')
            ]
            for i, x in enumerate([xz, xr, xh]):
                dim0 = i * hidden_size, (i + 1) * hidden_size
                starts, ends = zip(dim0, *extra_dims)
                init_net.Slice(name_from, x, starts=starts, ends=ends)
            if do_concat:
                init_net.Concat([xr, xz, xh],
                                ['%s/%s' % (name, name_to), dummy_name()],
                                axis=0)

        pred_mh.net = pred_mh.net.Clone("dummy-clone-net",
                                        blob_remap={
                                            hidden_t_all: n.outputs[0],
                                            hidden_t_last: n.outputs[1]
                                        })

        return Caffe2Ops(list(pred_mh.Proto().op), list(init_net.Proto().op),
                         list(pred_mh.Proto().external_input))
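The Slice/Concat sequence above exists because ONNX packs the GRU gate blocks in update/reset/hidden (z, r, h) order while the Caffe2 GRU cell expects reset/update/hidden (r, z, h). A small, self-contained NumPy sketch of the same reshuffle on an input weight matrix (illustration only; the names and shapes here are not part of the converter):

import numpy as np

hidden_size, input_size = 4, 3
# ONNX lays the GRU input weights out as stacked blocks [z; r; h],
# each block hidden_size rows tall.
W_onnx = np.random.randn(3 * hidden_size, input_size)

z = W_onnx[0 * hidden_size:1 * hidden_size]  # update gate
r = W_onnx[1 * hidden_size:2 * hidden_size]  # reset gate
h = W_onnx[2 * hidden_size:3 * hidden_size]  # hidden candidate

# The init_net Slice + Concat ops build this same [r; z; h] layout at
# graph-construction time for the blobs that gru_cell.GRU reads.
W_caffe2 = np.concatenate([r, z, h], axis=0)
assert W_caffe2.shape == (3 * hidden_size, input_size)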
Example #2
def make_cell(*args, **kwargs):
    return gru_cell.GRU(*args,
                        linear_before_reset=linear_before_reset,
                        **kwargs)
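This wrapper simply binds linear_before_reset so that callers which only know the generic cell-constructor interface never see the GRU-specific attribute. A self-contained illustration of the same binding pattern with a stand-in function (fake_gru is hypothetical, not a Caffe2 API):

import functools

def fake_gru(model, inputs, size, linear_before_reset=False):
    # Stand-in for gru_cell.GRU: just records what it was called with.
    return ('gru', model, inputs, size, linear_before_reset)

linear_before_reset = True

# Equivalent to the make_cell closure above: the attribute is fixed once,
# and the resulting factory is called with only the generic arguments.
make_cell = functools.partial(fake_gru,
                              linear_before_reset=linear_before_reset)
print(make_cell('model', ['x'], 8))   # ('gru', 'model', ['x'], 8, True)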
Example #3
        def make_gru(direction_offset):
            name = dummy_name()

            # input and recurrence biases are squashed together in
            # onnx but not in caffe2

            bias_offset = 6 * direction_offset * hidden_size
            Bi = init_net.Slice(B, name + "_bias_i2h",
                                starts=[bias_offset + 0 * hidden_size],
                                ends  =[bias_offset + 3 * hidden_size])
            Br = init_net.Slice(B, name + "_bias_gates",
                                starts=[bias_offset + 3 * hidden_size],
                                ends  =[bias_offset + 6 * hidden_size])

            weight_offset = 3 * direction_offset * hidden_size
            W_ = init_net.Slice(W, name + '/i2h_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 3 * hidden_size, -1])
            R_ = init_net.Slice(R, name + '/gates_t_w_pre',
                                starts=[weight_offset + 0 * hidden_size, 0],
                                ends  =[weight_offset + 3 * hidden_size, -1])

            # caffe2 has a different order from onnx. We need to rearrange
            #  z r h  -> r z h
            reforms = ((W_, 'i2h_w',    True,  [(0, -1)]),
                       (R_, 'gate_t_w', False, [(0, -1)]),
                       (Bi, 'i2h_b',    True,  []),
                       (Br, 'gate_t_b', False, []))
            for name_from, name_to, do_concat, extra_dims in reforms:
                xz, xr, xh = [
                    '%s/%s_%s' % (name, prefix, name_to)
                    for prefix in ('update', 'reset', 'output')
                ]
                for i, x in enumerate([xz, xr, xh]):
                    dim0 = i * hidden_size, (i+1) * hidden_size
                    starts, ends = zip(dim0, *extra_dims)
                    init_net.Slice(name_from, x, starts=starts, ends=ends)
                if do_concat:
                    init_net.Concat([xr, xz, xh],
                                    ['%s/%s' % (name, name_to), dummy_name()],
                                    axis=0)

            initial_h_sliced = name + '/initial_h'
            init_net.Slice(initial_h, initial_h_sliced,
                           starts=[direction_offset + 0, 0, 0],
                           ends  =[direction_offset + 1, -1, -1])

            if direction_offset == 1:
                input = pred_mh.net.ReversePackedSegs(
                    [input_blob, sequence_lens], name + "/input-reversed")
            else:
                input = input_blob

            hidden_t_all, hidden_t_last = gru_cell.GRU(
                pred_mh,
                input,
                sequence_lens,
                [initial_h_sliced],
                input_size,
                hidden_size,
                name,
                drop_states=True,
                forward_only=True,
                linear_before_reset=linear_before_reset
            )

            if direction_offset == 1:
                hidden_t_all = pred_mh.net.ReversePackedSegs(
                    [hidden_t_all, sequence_lens], name + "/output-reversed")

            return hidden_t_all, hidden_t_last
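For the reverse direction the converter flips each sequence over its valid length with ReversePackedSegs, runs the same forward-only GRU, and flips the outputs back. A self-contained NumPy sketch of that per-sequence reversal (an illustration of the op's semantics, not the converter itself):

import numpy as np

def reverse_packed_segs(data, lengths):
    # data: (max_seq_len, batch, feature); lengths: valid length per batch item.
    # Only the first lengths[b] steps of each sequence are reversed;
    # padding positions are left untouched.
    out = data.copy()
    for b, valid in enumerate(lengths):
        out[:valid, b] = data[:valid, b][::-1]
    return out

data = np.arange(4 * 2 * 1).reshape(4, 2, 1)
lengths = [4, 2]
reversed_once = reverse_packed_segs(data, lengths)
# Reversing twice restores the original, which is why the backward pass
# reverses the input before the GRU and the output after it.
assert np.array_equal(reverse_packed_segs(reversed_once, lengths), data)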