示例#1
0
 def extract(cls, node: Node):
     """Set the Cast node's ``dst_type`` from the layer's ``dtype`` attribute.

     The layer attributes are read from ``node.symbol_dict`` via
     ``get_mxnet_layer_attrs``. The ``'dtype'`` string is fetched with
     ``'float32'`` supplied as the default argument. It is converted with
     ``np.dtype`` and stored on the node through ``Cast.update_node_stat``.

     Returns the class attribute ``enabled``.
     """
     layer_attrs = get_mxnet_layer_attrs(node.symbol_dict)
     dtype_name = layer_attrs.str('dtype', 'float32')
     cast_params = {'dst_type': np.dtype(dtype_name)}
     Cast.update_node_stat(node, cast_params)
     return cls.enabled
示例#2
0
 def extract(cls, node):
     """Set the Cast node's ``dst_type`` from the ONNX ``to`` attribute.

     The ``to`` attribute is read with ``onnx_attr`` (field ``'i'``,
     ``default=None``). It is mapped to a numpy type by
     ``get_onnx_datatype_as_numpy`` and stored through
     ``Cast.update_node_stat``.

     NOTE(review): when ``to`` is absent, ``None`` is passed straight to
     ``get_onnx_datatype_as_numpy``; how that helper reacts is not visible
     here. Confirm it reports a clear error.

     Returns the class attribute ``enabled``.
     """
     onnx_type_code = onnx_attr(node, 'to', 'i', default=None)
     numpy_type = get_onnx_datatype_as_numpy(onnx_type_code)
     Cast.update_node_stat(node, {'dst_type': numpy_type})
     return cls.enabled
示例#3
0
 def extract(cls, node):
     """Set the Cast node's ``dst_type`` from the protobuf ``DstT`` attribute.

     The ``type`` field of ``node.pb.attr['DstT']`` is looked up in
     ``tf_data_type_decode``. The first element of that entry becomes
     ``dst_type``, presumably the numpy type (confirm against the table's
     definition). The value is stored through ``Cast.update_node_stat``.

     Returns the class attribute ``enabled``.
     """
     dst_type_code = node.pb.attr['DstT'].type
     decoded_entry = tf_data_type_decode[dst_type_code]
     Cast.update_node_stat(node, {'dst_type': decoded_entry[0]})
     return cls.enabled
示例#4
0
 def extract(cls, node: Node):
     """Copy ``node.module.dst_type`` onto the Cast node as ``dst_type``.

     The value is forwarded unchanged to ``Cast.update_node_stat``.

     Returns the class attribute ``enabled``.
     """
     target_type = node.module.dst_type
     Cast.update_node_stat(node, {'dst_type': target_type})
     return cls.enabled