예제 #1
0
 def extract(cls, node):
     """Mark ``node`` as a Const whose payload is its own ``value`` attribute.

     The array stored in ``node.value`` becomes the constant's ``value`` and
     its ``dtype`` is recorded as ``data_type``.

     Returns ``cls.enabled``.
     """
     constant = node.value
     Const.update_node_stat(node, {'data_type': constant.dtype, 'value': constant})
     return cls.enabled
예제 #2
0
 def extract(cls, node):
     """Populate Const attributes from the tensor in ``node.pb.attr["value"]``.

     The tensor's shape, decoded content and dtype are obtained through the
     ``tf_tensor_shape`` / ``tf_tensor_content`` / ``tf_dtype_extractor``
     helpers and written onto ``node`` via ``Const.update_node_stat``.

     Returns ``cls.enabled``.
     """
     tensor_proto = node.pb.attr["value"].tensor
     tensor_shape = tf_tensor_shape(tensor_proto.tensor_shape)
     tensor_value = tf_tensor_content(tensor_proto.dtype, tensor_shape, tensor_proto)
     tensor_type = tf_dtype_extractor(tensor_proto.dtype)
     Const.update_node_stat(node, {
         'shape': tensor_shape,
         'value': tensor_value,
         'data_type': tensor_type,
     })
     return cls.enabled
예제 #3
0
    def extract(cls, node):
        """Convert the node's ONNX ``value`` tensor attribute into Const attributes.

        The tensor is read with ``onnx_attr(node, 'value', 't')``, converted
        by ``numpy_helper.to_array``, and stored on ``node`` along with its
        dtype.

        Returns ``cls.enabled``.
        """
        array = numpy_helper.to_array(onnx_attr(node, 'value', 't'))
        Const.update_node_stat(node, dict(data_type=array.dtype, value=array))
        return cls.enabled
예제 #4
0
 def replace_sub_graph(self, graph: Graph, match: dict):
     """Rewrite the matched FakeConst node in place as a Const node.

     :param graph: graph being transformed (not read by this method).
     :param match: pattern match; ``match['op']`` is the FakeConst node.

     If the node has no valid ``value`` attribute it is left unchanged and a
     debug message is logged. Otherwise the node's ``value``, that value's
     shape, and the dtype decoded from ``node.pb.attr['dtype']`` are applied
     through ``Const.update_node_stat``.
     """
     fake_const = match['op']
     if fake_const.has_valid('value'):
         tensor = fake_const.value
         const_attrs = dict(
             data_type=tf_dtype_extractor(fake_const.pb.attr['dtype'].type),
             shape=int64_array(tensor.shape),
             value=tensor,
         )
         Const.update_node_stat(fake_const, const_attrs)
         log.debug(
             'FakeConst op was translated to Const op with shape = {} and value.shape = {}'.format(
                 const_attrs['shape'], const_attrs['value'].shape))
     else:
         log.debug("No value in FakeConst node {}".format(fake_const.id))
예제 #5
0
    def extract(cls, node):
        """Build a zero-filled Const from an MXNet node's ``shape`` attribute.

        Every dimension equal to 0 in ``shape`` is replaced by 1 before the
        zero array is created. The indices of those dimensions are recorded
        in ``zero_shapes`` (presumably so a later pass can restore the empty
        axes; confirm against consumers of ``zero_shapes``).

        Returns ``cls.enabled``.
        """
        attrs = get_mxnet_layer_attrs(node.symbol_dict)
        shape = list(attrs.tuple('shape', int, None))
        # Record which axes were 0, then materialize them as 1.
        zero_shapes = [axis for axis, dim in enumerate(shape) if dim == 0]
        shape = [1 if dim == 0 else dim for dim in shape]

        update_attrs = {
            # 'shape' must hold the dimensions themselves. np.ndarray(shape)
            # instead allocates an uninitialized array *of* that shape, whose
            # contents are arbitrary memory.
            'shape': np.array(shape, dtype=np.int64),
            'value': np.zeros(shape),
            'zero_shapes': zero_shapes,
        }

        # update the attributes of the node
        Const.update_node_stat(node, update_attrs)
        return cls.enabled
예제 #6
0
 def extract(cls, node):
     """Classify ``node`` as a Const or a Parameter from its symbol dict.

     When ``node.symbol_dict`` has a ``'value'`` entry, the node becomes a
     Const carrying that value. Otherwise it becomes a Parameter with no
     extra attributes.

     Returns ``cls.enabled``.
     """
     symbols = node.symbol_dict
     if 'value' not in symbols:
         Parameter.update_node_stat(node, {})
         return cls.enabled
     Const.update_node_stat(node, {'value': symbols['value']})
     return cls.enabled
예제 #7
0
 def extract(cls, node):
     """Turn ``node`` into a Const built from the proto in ``node.pb_init``.

     The proto is converted with ``to_array``. The resulting array and its
     dtype are applied to the node through ``Const.update_node_stat``.

     Returns ``cls.enabled``.
     """
     array = to_array(node.pb_init)
     Const.update_node_stat(node, {
         'data_type': array.dtype,
         'value': array,
     })
     return cls.enabled