コード例 #1
0
 def extract(cls, node):
     """Populate the node's attributes from its ``dtype`` and ``shape`` entries.

     Reads ``node.pb.attr["dtype"].type`` and ``node.pb.attr["shape"].shape``,
     converts them with ``tf_dtype_extractor`` and ``tf_tensor_shape``, and
     hands the result to ``Op.update_node_stat`` together with an
     ``identity`` flag and an ``infer`` callback.

     :param node: object exposing a ``pb`` member with an ``attr`` mapping.
     :return: ``cls.enabled``.
     """
     pb_attr = node.pb.attr
     data_type = tf_dtype_extractor(pb_attr["dtype"].type)
     shape = tf_tensor_shape(pb_attr["shape"].shape)
     Op.update_node_stat(node, {
         'data_type': data_type,
         'shape': shape,
         'identity': True,
         # The callback forwards to copy_shape_infer, using copy_value as the
         # value-inference function.
         'infer': lambda node: copy_shape_infer(node, value_infer=copy_value),
     })
     return cls.enabled
コード例 #2
0
 def extract(cls, node):
     """Collect per-component shapes and types from ``node.pb.attr``.

     Every entry of ``component_types`` is converted with
     ``tf_dtype_extractor``. Every entry of ``shapes`` becomes an
     ``int64_array`` of its dimension sizes; a shape with exactly three
     dimensions gets a leading ``1`` prepended, so it is stored with four
     entries.

     :param node: object exposing a ``pb`` member with an ``attr`` mapping.
     :return: ``cls.enabled``.
     """
     def dim_sizes(dims):
         # Rank-3 shapes are stored as [1, d0, d1, d2]; all others unchanged.
         sizes = [dim.size for dim in dims]
         return [1] + sizes if len(sizes) == 3 else sizes

     shape_protos = node.pb.attr['shapes'].list.shape
     type_enums = node.pb.attr['component_types'].list.type
     types = [tf_dtype_extractor(type_enum) for type_enum in type_enums]
     shapes = [int64_array(dim_sizes(proto.dim)) for proto in shape_protos]
     Op.update_node_stat(node, {'shapes': shapes, 'types': types})
     return cls.enabled
コード例 #3
0
 def extract(cls, node):
     """Collect output shapes and types from ``node.pb.attr``.

     Converts each entry of ``output_types`` with ``tf_dtype_extractor`` and
     each entry of ``output_shapes`` with ``tf_tensor_shape``, then passes
     both lists to ``Op.update_node_stat`` along with ``out_ports_count``
     set to 1.

     :param node: object exposing a ``pb`` member with an ``attr`` mapping.
     :return: ``cls.enabled``.
     """
     shape_protos = node.pb.attr['output_shapes'].list.shape
     type_enums = node.pb.attr['output_types'].list.type
     types = [tf_dtype_extractor(type_enum) for type_enum in type_enums]
     shapes = [tf_tensor_shape(proto) for proto in shape_protos]
     node_attrs = {
         'shapes': shapes,
         'types': types,
         'out_ports_count': 1,
     }
     Op.update_node_stat(node, node_attrs)
     return cls.enabled
コード例 #4
0
    def extract(cls, node):
        """Record quantization parameters read from ``node.pb.attr``.

        Reads the boolean ``narrow_range`` and the integer ``num_bits``,
        derives ``levels`` as ``2**num_bits`` minus one when ``narrow_range``
        is set, and stores all three on the node under the op name
        ``'FakeQuantWithMinMaxVars'``.

        :param node: object exposing a ``pb`` member with an ``attr`` mapping.
        :return: ``cls.enabled``.
        """
        pb_attr = node.pb.attr
        narrow_range = pb_attr['narrow_range'].b
        num_bits = pb_attr['num_bits'].i

        # This operation is prepared for conversion to a FakeQuantize op, but
        # input reconnection is still needed, so neither the infer function
        # nor the type attribute is set here.
        quant_attrs = {
            'op': 'FakeQuantWithMinMaxVars',
            'levels': 2 ** num_bits - int(narrow_range),
            'narrow_range': narrow_range,
            'num_bits': num_bits,
        }
        Op.update_node_stat(node, quant_attrs)

        return cls.enabled
コード例 #5
0
 def extract(cls, node):
     """Apply the attributes returned by ``get_attrs(node)`` to ``node``.

     :param node: node passed to both ``get_attrs`` and
         ``Op.update_node_stat``.
     :return: ``cls.enabled``.
     """
     Op.update_node_stat(node, get_attrs(node))
     return cls.enabled
コード例 #6
0
 def extract(cls, node):
     """Pass ``{'op': 'FakeConst'}`` to ``Op.update_node_stat`` for ``node``.

     :param node: node to update.
     :return: ``cls.enabled``.
     """
     fake_const_attrs = dict(op='FakeConst')
     Op.update_node_stat(node, fake_const_attrs)
     return cls.enabled