Example #1
    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'kind': 'op',
            'type': self.op,
            'op': self.op,
            'version': 'opset6',
            'eps': None,
            'normalize_variance': None,
            'eps_mode': None,
            'in_ports_count': 2,
            'out_ports_count': 1,
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        }, attrs)
Example #2
    def __init__(self, graph: Graph, attrs: dict):
        assert 'alpha' in attrs, 'LRN operation should have `alpha` parameter set while creation'
        assert 'beta' in attrs, 'LRN operation should have `beta` parameter set while creation'
        assert 'bias' in attrs, 'LRN operation should have `bias` parameter set while creation'
        assert 'size' in attrs, 'LRN operation should have `size` parameter set while creation'
        assert 'region' not in attrs, \
            'LRN operation should not have `region` parameter set while creation, please use AttributedLRN operation ' \
            'instead or keep using LRN operation with region expressed as second `axis`-input'

        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset1',
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
            'in_ports_count': 2,
            'out_ports_count': 1,
        }, attrs)
Example #3
    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': None,
            'op': self.op,
            'in_ports_count': 5,
            'out_ports_count': 1,
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        }, attrs)
Example #4
def batch_norm_ext(attrs):
    node_attrs = {
        'type': 'BatchNormalization',
        'eps': attrs.float('eps', 0.001),
        'infer': batch_norm_4_infer,
        'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        'fix_gamma': attrs.bool('fix_gamma', False)
    }
    node_attrs.update(layout_attrs())
    return node_attrs
Example #5
    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': None,
            'axis': None,
            'op': self.op,
            'in_ports_count': 2,
            'out_ports_count': 1,
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        }
        super().__init__(graph, mandatory_props, attrs)
Example #6
    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': __class__.op,
            'op': __class__.op,
            'version': 'opset1',
            'in_ports_count': 1,
            'out_ports_count': 1,
            'infer': copy_shape_infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        }
        super().__init__(graph, mandatory_props, attrs)
Example #7
def tf_fused_bn_extractor(pb):
    is_training = pb.attr['is_training'].b
    if is_training:
        log.warning('FusedBatchNorm doesn\'t support is_training=True')

    return {
        'data_format': pb.attr["data_format"].s,
        'data_type': tf_dtype_extractor(pb.attr["T"].type),
        'eps': pb.attr['epsilon'].f,
        'infer': tf_fused_bn_infer,
        'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
        'is_training': is_training
    }
Example #8
    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'op': self.op,
            'type': 'Convert',
            'version': 'opset1',
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
            'type_infer': self.type_infer,
            'dst_type': None,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)
Example #9
    def __init__(self, graph: Graph, attrs: dict):
        assert self.op is not None and self.op_type is not None and self.version is not None, \
            'Please use specialized Scatter operation class, Scatter is base class'

        mandatory_props = {
            'op': self.op,
            'type': self.op_type,
            'version': self.version,

            'is_scatter': True,  # is used for gathering all types of scatters in common transformations
            'infer': self.infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),

            'in_ports_count': 4,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)
Example #10
    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset7',
            'infer': roll_infer,
            'reverse_infer': lambda node: reverse_bypass_infer(node, in_ports=[0]),
            'in_ports_count': 3,
            'out_ports_count': 1,
        }, attrs)