def _create_gru(cls, init_model, pred_model, n, opset_version):
    """Translate a single ONNX GRU node into Caffe2 operators.

    The GRU itself is built with ``gru_cell.GRU`` on a fresh ``ModelHelper``;
    a separate init net slices the ONNX bias/weight blobs into the
    per-gate blobs whose names are hardcoded below.

    Args:
        init_model: must not be None; forwarded to
            ``cls._rnn_shape_inference``.
        pred_model: must not be None; forwarded to
            ``cls._rnn_shape_inference``.
        n: the node being converted. It must expose ``attrs``, exactly six
            ``inputs`` (input, W, R, B, sequence_lens, initial_h) and at
            least two ``outputs``.
        opset_version: not used by this function.

    Returns:
        A ``Caffe2Ops`` built from the predict-net ops, the init-net ops and
        the predict net's external inputs.

    Raises:
        AssertionError: if either model is None or the node carries an
            attribute other than ``hidden_size`` / ``linear_before_reset``.
        KeyError: if the node has no ``hidden_size`` attribute.
        RuntimeError: if the input size could not be inferred.
    """
    assert init_model is not None, "cannot convert GRUs without access to the full model"
    assert pred_model is not None, "cannot convert GRUs without access to the full model"

    attrs = dict(n.attrs)  # make a copy, which is safe to mutate
    hidden_size = attrs.pop('hidden_size')
    # ONNX declares linear_before_reset as optional with a default of 0, so
    # a node that omits it must not fail with KeyError.
    linear_before_reset = attrs.pop('linear_before_reset', 0)
    assert not attrs, "unsupported GRU attributes: " + str(attrs.keys())

    input_blob, W, R, B, sequence_lens, initial_h = n.inputs

    # An empty string stands for "no sequence_lens input".
    if sequence_lens == "":
        sequence_lens = None

    input_size = cls._rnn_shape_inference(init_model, pred_model, n, input_blob, W)
    if input_size is None:
        raise RuntimeError(
            "best-effort shape inference for GRU input failed")

    name = dummy_name()
    init_net = core.Net("init-net")
    pred_mh = ModelHelper()

    hidden_t_all, hidden_t_last = gru_cell.GRU(
        pred_mh,
        input_blob,
        sequence_lens,
        [initial_h],
        input_size,
        hidden_size,
        name,
        drop_states=True,
        forward_only=True,
        linear_before_reset=linear_before_reset)

    # input and recurrence biases are squashed together in onnx but not in caffe2
    Bi = name + "_bias_i2h"
    Br = name + "_bias_gates"
    init_net.Slice(B, Bi, starts=[0 * hidden_size], ends=[3 * hidden_size])
    init_net.Slice(B, Br, starts=[3 * hidden_size], ends=[6 * hidden_size])

    # caffe2 has a different order from onnx. We need to rearrange
    #  z r h  -> r z h
    #
    # TODO implement support for return_params in gru_cell.GRU.
    # Until then, hardcode blob names.
    #
    # Each entry: (source blob, target suffix, whether the three gate slices
    # are concatenated back together, (start, end) pairs for the dims after
    # the first).
    reforms = ((W, 'i2h_w', True, [(0, input_size)]),
               (R, 'gate_t_w', False, [(0, hidden_size)]),
               (Bi, 'i2h_b', True, []),
               (Br, 'gate_t_b', False, []))
    for name_from, name_to, do_concat, extra_dims in reforms:
        xz, xr, xh = ['%s/%s_%s' % (name, prefix, name_to)
                      for prefix in ('update', 'reset', 'output')]
        for i, x in enumerate([xz, xr, xh]):
            # Gate i occupies rows [i * hidden_size, (i + 1) * hidden_size).
            dim0 = i * hidden_size, (i + 1) * hidden_size
            starts, ends = zip(dim0, *extra_dims)
            init_net.Slice(name_from, x, starts=starts, ends=ends)
        if do_concat:
            # Reassemble in r, z, h order; the second output gets a throwaway name.
            init_net.Concat([xr, xz, xh],
                            ['%s/%s' % (name, name_to), dummy_name()],
                            axis=0)

    # Rename the GRU outputs to the names the ONNX node declares.
    pred_mh.net = pred_mh.net.Clone(
        "dummy-clone-net",
        blob_remap={hidden_t_all: n.outputs[0],
                    hidden_t_last: n.outputs[1]})

    return Caffe2Ops(list(pred_mh.Proto().op),
                     list(init_net.Proto().op),
                     list(pred_mh.Proto().external_input))
def make_cell(*args, **kwargs):
    """Call ``gru_cell.GRU`` with ``linear_before_reset`` from the enclosing scope.

    All positional and keyword arguments are forwarded unchanged; the
    result of ``gru_cell.GRU`` is returned as-is.
    """
    cell = gru_cell.GRU(
        *args,
        linear_before_reset=linear_before_reset,
        **kwargs
    )
    return cell
def make_gru(direction_offset):
    """Emit the Caffe2 ops for one direction of the GRU being converted.

    ``direction_offset`` picks which slice of ``B``, ``W``, ``R`` and
    ``initial_h`` is used. When it equals 1 the input is reversed with
    ``ReversePackedSegs`` before the GRU and the full hidden-state output
    is reversed again afterwards.

    Returns:
        The ``(hidden_t_all, hidden_t_last)`` pair from ``gru_cell.GRU``
        (with ``hidden_t_all`` re-reversed when ``direction_offset == 1``).
    """
    scope = dummy_name()

    # input and recurrence biases are squashed together in
    # onnx but not in caffe2
    bias_base = 6 * direction_offset * hidden_size
    bias_split = bias_base + 3 * hidden_size
    bias_end = bias_base + 6 * hidden_size
    bias_i2h = init_net.Slice(
        B, scope + "_bias_i2h",
        starts=[bias_base], ends=[bias_split])
    bias_gates = init_net.Slice(
        B, scope + "_bias_gates",
        starts=[bias_split], ends=[bias_end])

    # Rows of W and R that belong to this direction.
    row_lo = 3 * direction_offset * hidden_size
    row_hi = row_lo + 3 * hidden_size
    w_dir = init_net.Slice(
        W, scope + '/i2h_w_pre',
        starts=[row_lo, 0], ends=[row_hi, -1])
    r_dir = init_net.Slice(
        R, scope + '/gates_t_w_pre',
        starts=[row_lo, 0], ends=[row_hi, -1])

    # caffe2 has a different order from onnx. We need to rearrange
    #   z r h  -> r z h
    #
    # Each entry: (source blob, target suffix, concat the gates back
    # together?, (start, end) pairs for the dims after the first).
    gate_layout = (
        (w_dir, 'i2h_w', True, [(0, -1)]),
        (r_dir, 'gate_t_w', False, [(0, -1)]),
        (bias_i2h, 'i2h_b', True, []),
        (bias_gates, 'gate_t_b', False, []),
    )
    for src_blob, suffix, needs_concat, tail_dims in gate_layout:
        gate_blobs = ['%s/%s_%s' % (scope, gate, suffix)
                      for gate in ('update', 'reset', 'output')]
        update_blob, reset_blob, output_blob = gate_blobs

        tail_starts = tuple(first for first, _ in tail_dims)
        tail_ends = tuple(last for _, last in tail_dims)
        for gate_index, gate_blob in enumerate(gate_blobs):
            lo = gate_index * hidden_size
            hi = (gate_index + 1) * hidden_size
            init_net.Slice(src_blob, gate_blob,
                           starts=(lo,) + tail_starts,
                           ends=(hi,) + tail_ends)

        if needs_concat:
            init_net.Concat(
                [reset_blob, update_blob, output_blob],
                ['%s/%s' % (scope, suffix), dummy_name()],
                axis=0)

    # Pick this direction's slice of the initial hidden state.
    initial_h_sliced = scope + '/initial_h'
    init_net.Slice(initial_h, initial_h_sliced,
                   starts=[direction_offset + 0, 0, 0],
                   ends=[direction_offset + 1, -1, -1])

    is_reverse = direction_offset == 1

    gru_input = input_blob
    if is_reverse:
        gru_input = pred_mh.net.ReversePackedSegs(
            [input_blob, sequence_lens], scope + "/input-reversed")

    hidden_t_all, hidden_t_last = gru_cell.GRU(
        pred_mh,
        gru_input,
        sequence_lens,
        [initial_h_sliced],
        input_size,
        hidden_size,
        scope,
        drop_states=True,
        forward_only=True,
        linear_before_reset=linear_before_reset)

    if is_reverse:
        hidden_t_all = pred_mh.net.ReversePackedSegs(
            [hidden_t_all, sequence_lens], scope + "/output-reversed")

    return hidden_t_all, hidden_t_last