예제 #1
0
 def concatenation(self, node: onnx.NodeProto) -> spec.Spec:
     """Build a ``spec.Spec`` describing a concatenation node.

     Args:
         node: The ONNX node to convert.

     Returns:
         A ``spec.Spec`` pairing the node's identifier with a
         ``spec.Concatenation`` option. The option holds every input shape
         (as plain lists) and the concatenation axis. The axis is passed
         through ``spec_utils.implicit_axis_to_explicit`` relative to the
         first input's shape (presumably resolving negative axes; confirm
         against ``spec_utils``).
     """
     input_shapes, _, attributes = self.get_inputs_for_gen_spec(node)

     tensor_shapes = [list(shape) for shape in input_shapes]
     concat_axis = spec_utils.implicit_axis_to_explicit(
         attributes['axis'], input_shapes[0])

     concat_option = spec.Concatenation(tensors=tensor_shapes,
                                        axis=concat_axis)
     return spec.Spec(spec_utils.node_identifier(node), concat_option)
예제 #2
0
 def flatten(self, node: onnx.NodeProto) -> spec.Spec:
     """Build a ``spec.Spec`` describing a single-input flatten node.

     Args:
         node: The ONNX node to convert.

     Returns:
         A ``spec.Spec`` pairing the node's identifier with a
         ``spec.Flatten`` option holding the input shape (as a list) and
         the flatten axis. The axis is passed through
         ``spec_utils.implicit_axis_to_explicit``.

     Raises:
         AssertionError: If the node does not yield exactly one input shape
             (only while assertions are enabled).
     """
     input_shapes, _, attributes = self.get_inputs_for_gen_spec(node)
     assert len(input_shapes) == 1
     input_shape = input_shapes[0]
     # ONNX declares Flatten's `axis` optional with a default of 1 in every
     # opset. Fall back to that default instead of raising KeyError when a
     # model omits the attribute.
     axis = attributes.get('axis', 1)
     operator_spec_option = spec.Flatten(
         shape=[*input_shape],
         axis=spec_utils.implicit_axis_to_explicit(axis, input_shape))
     return spec.Spec(spec_utils.node_identifier(node),
                      operator_spec_option)
예제 #3
0
 def transpose(self, node: onnx.NodeProto) -> spec.Spec:
     """Build a ``spec.Spec`` describing a single-input transpose node.

     Args:
         node: The ONNX node to convert.

     Returns:
         A ``spec.Spec`` pairing the node's identifier with a
         ``spec.Transpose`` option holding the input shape (as a list) and
         the permutation. The permutation is passed through
         ``spec_utils.implicit_axis_to_explicit``.

     Raises:
         AssertionError: If the node does not yield exactly one input shape
             (only while assertions are enabled).
     """
     input_shapes, _, attributes = self.get_inputs_for_gen_spec(node)
     assert len(input_shapes) == 1
     input_shape = input_shapes[0]
     perm = attributes.get('perm')
     if perm is None:
         # ONNX makes `perm` optional: when omitted, Transpose reverses the
         # order of the dimensions. Use that default instead of raising
         # KeyError.
         perm = range(len(input_shape) - 1, -1, -1)
     operator_spec_option = spec.Transpose(
         shape=[*input_shape],
         permutation=spec_utils.implicit_axis_to_explicit(
             [*perm], input_shape))
     return spec.Spec(spec_utils.node_identifier(node),
                      operator_spec_option)
예제 #4
0
    def softmax(self, node: onnx.NodeProto) -> spec.Spec:
        """Build a ``spec.Spec`` describing a single-input softmax node.

        Args:
            node: The ONNX node to convert.

        Returns:
            A ``spec.Spec`` pairing the node's identifier with a
            ``spec.Softmax`` option. The option holds the input shape (as a
            list), ``beta`` (1.0 when the attribute is absent) and the axis.
            The axis is passed through ``spec_utils.implicit_axis_to_explicit``.

        Raises:
            AssertionError: If the node does not yield exactly one input
                shape (only while assertions are enabled).
        """
        input_shapes, _, attributes = self.get_inputs_for_gen_spec(node)
        assert len(input_shapes) == 1
        (shape,) = input_shapes

        # NOTE(review): `axis` is read without a default. ONNX's default for
        # Softmax changed with opset 13 (1 -> -1), so none is assumed here.
        softmax_option = spec.Softmax(
            input_shape=list(shape),
            beta=attributes.get('beta', 1.0),
            axis=spec_utils.implicit_axis_to_explicit(attributes['axis'],
                                                      shape),
        )
        identifier = spec_utils.node_identifier(node)
        return spec.Spec(identifier, softmax_option)