예제 #1
0
def common_pool_extender(op: Node):
    """Normalize the attributes of a pooling node restored from IR.

    Converts the IR attributes ``strides``, ``pads_begin``, ``pads_end`` and
    ``kernel`` with ``Extender.attr_to_list`` and derives from them the
    internal attributes ``stride``, ``window``, ``kernel_spatial``,
    ``batch_dims``, ``channel_dims``, ``pad`` and ``spatial_dims``. When
    ``rounding_type`` is ``'ceil'`` it also sets ``pooling_convention`` to
    ``'full'``.

    :param op: pooling node, updated in place.
    :raises AssertionError: if ``pads_begin`` does not describe 1, 2 or 3
        spatial dimensions.
    """
    for attr in ['strides', 'pads_begin', 'pads_end', 'kernel']:
        Extender.attr_to_list(op, attr)
    # Prefix 1s for the batch (0) and channel (1) axes so stride/window cover the full rank.
    op['stride'] = int64_array([1, 1] + op.strides)
    op['window'] = int64_array([1, 1] + op.kernel)
    op['kernel_spatial'] = op.kernel
    op['output_spatial_shape'] = None

    # No trailing commas here: `x = value,` would store a 1-tuple wrapping the
    # array instead of the array itself.
    op['batch_dims'] = int64_array([0])
    op['channel_dims'] = int64_array([1])

    dim = len(op.pads_begin)

    assert dim in (1, 2, 3), '{}D {} not supported! Node name: {}'.format(
        dim, op.soft_get('type'), op.soft_get('name', op.id))

    # Zero padding for batch and channel axes, then one [begin, end] pair per spatial axis.
    pad = [[0, 0], [0, 0]]
    pad.extend([[op.pads_begin[i], op.pads_end[i]] for i in range(dim)])

    op['pad'] = int64_array(pad)

    op['spatial_dims'] = [i + 2 for i in range(dim)]

    if op.has_valid('rounding_type') and op.rounding_type == 'ceil':
        op['pooling_convention'] = 'full'
    def extend(op: Node):
        """Normalize the attributes of a Convolution node restored from IR.

        The IR attributes ``strides``, ``dilations``, ``pads_begin``,
        ``pads_end`` and ``output_padding`` go through
        ``Extender.attr_to_list``. The internal attributes ``stride``,
        ``dilation``, ``batch_dims``, ``channel_dims``,
        ``input_feature_channel``, ``output_feature_channel``, ``pad`` and
        ``spatial_dims`` are then derived from them.

        :param op: convolution node, updated in place.
        :raises AssertionError: if ``pads_begin`` does not describe 1, 2 or 3
            spatial dimensions.
        """
        list_attrs = ('strides', 'dilations', 'pads_begin', 'pads_end', 'output_padding')
        for attr_name in list_attrs:
            Extender.attr_to_list(op, attr_name)

        # Batch and channel axes get a stride/dilation of 1.
        op['stride'] = int64_array([1, 1] + op.strides)
        op['dilation'] = int64_array([1, 1] + op.dilations)

        op['batch_dims'] = int64_array([0])
        op['channel_dims'] = int64_array([1])

        # Feature-channel axes (presumably of the weights tensor - confirm with consumers).
        # Be VERY careful when changing these values!
        op['input_feature_channel'] = 1
        op['output_feature_channel'] = 0

        spatial_rank = len(op.pads_begin)
        assert spatial_rank in (1, 2, 3), '{}D Convolution not supported!'.format(spatial_rank)

        # No padding on batch/channel axes, then one [begin, end] pair per spatial axis.
        begin_end_pairs = [[op.pads_begin[axis], op.pads_end[axis]] for axis in range(spatial_rank)]
        op['pad'] = int64_array([[0, 0], [0, 0]] + begin_end_pairs)

        op['spatial_dims'] = list(range(2, spatial_rank + 2))
예제 #3
0
def common_backpropdata_extender(op: Node):
    """Normalize the attributes of a backprop-data node restored from IR.

    Converts ``strides``, ``output_padding``, ``pads_begin``, ``pads_end``
    and ``dilations`` with ``Extender.attr_to_list``. It then derives the
    following internal attributes:

    * ``pad``, only when both pad attributes are valid;
    * ``spatial_dims``;
    * full-rank ``dilation`` and ``stride``, with ``dilations`` and
      ``strides`` defaulting to 1 per spatial axis when invalid.

    Finally it sets ``infer`` to ``backpropdata_infer``.

    :param op: node, updated in place.
    """
    for attr in ['strides', 'output_padding', 'pads_begin', 'pads_end', 'dilations']:
        Extender.attr_to_list(op, attr)

    if op.has_valid('output_padding'):
        # Zero output padding for the batch and channel axes.
        op.output_padding = int64_array([0, 0] + op.output_padding)

    # Number of spatial axes. Taking it only from `strides` made the `strides`
    # default below unreachable (len() of an invalid attribute fails first), so
    # fall back to the other per-axis attributes when `strides` is not valid.
    # If none is valid we still read `strides`, failing exactly as before.
    rank_attr = next((name for name in ('strides', 'dilations', 'pads_begin') if op.has_valid(name)),
                     'strides')
    dim = len(getattr(op, rank_attr))

    if op.has_valid('pads_begin') and op.has_valid('pads_end'):
        # Zero padding for batch/channel axes, then one [begin, end] pair per spatial axis.
        pad = [[0, 0], [0, 0]]
        pad.extend([[op.pads_begin[i], op.pads_end[i]] for i in range(dim)])

        op['pad'] = int64_array(pad)

    op['spatial_dims'] = [i + 2 for i in range(dim)]

    # Unit dilation/stride on every spatial axis when the IR does not provide them.
    if not op.has_valid('dilations'):
        op['dilations'] = [1] * dim
    if not op.has_valid('strides'):
        op['strides'] = [1] * dim

    op['dilation'] = int64_array([1, 1] + op.dilations)
    op['stride'] = int64_array([1, 1] + op.strides)

    op['infer'] = backpropdata_infer
예제 #4
0
 def attr_restore(node: Node, attribute: str, value=None):
     """Restore a list-valued attribute of a PriorBox / PriorBoxClustered node.

     A missing/invalid attribute is first set to ``[value]``, or to ``[]``
     when ``value`` is None. A string-valued attribute is then replaced with
     ``[]``. Any other value goes through ``Extender.attr_to_list``.

     :param node: node to update in place.
     :param attribute: name of the attribute to restore.
     :param value: optional single element used when the attribute is missing.
     """
     if not node.has_valid(attribute):
         node[attribute] = [value] if value is not None else []

     if not isinstance(node[attribute], str):
         Extender.attr_to_list(node, attribute)
         return

     # NOTE(review): string values (presumably empty IR attributes) are dropped entirely.
     node[attribute] = []
예제 #5
0
 def extend(op: Node):
     """Normalize the attributes of a Parameter node restored from IR.

     Sets ``data_type`` from the IR ``element_type``. An empty-string
     ``shape`` becomes an empty int64 array; any other ``shape`` goes
     through ``Extender.attr_to_list``.

     :param op: Parameter node, updated in place.
     :raises AssertionError: if ``element_type`` is missing or invalid.
     """
     # soft_get() keeps the message formatting itself from failing on a node
     # without a 'name' attribute (same pattern as the other extenders).
     assert op.has_valid('element_type'), \
         'Parameter node {} has missed element_type attr!'.format(op.soft_get('name', op.id))
     op['data_type'] = destination_type_to_np_data_type(op.element_type)
     # Only a string can be the empty-shape marker; comparing a list/array to ''
     # is not a meaningful test, so check the type first.
     if isinstance(op.shape, str) and op.shape == '':
         op.shape = int64_array([])
     else:
         Extender.attr_to_list(op, 'shape')
예제 #6
0
    def extend(op: Node):
        """Normalize the mask attributes of a StridedSlice node restored from IR.

        Every mask that is present and not an empty string goes through
        ``Extender.attr_to_list``. Any other optional mask defaults to
        ``[0]``. ``begin_mask`` and ``end_mask`` are then flipped element-wise
        (``1 - value``).

        :param op: StridedSlice node, updated in place.
        :raises AssertionError: if ``begin_mask`` or ``end_mask`` is not defined.
        """
        attrs = [
            'shrink_axis_mask', 'new_axis_mask', 'ellipsis_mask', 'begin_mask',
            'end_mask'
        ]
        for attr in attrs:
            # We can not use op.has_and_set(attr) here as a condition, because it will return False if
            # begin/end is 1D tensor and begin_mask/end_mask is equal to 0
            if op.has(attr) and op[attr] != '':
                Extender.attr_to_list(op, attr)
            else:
                # begin/end masks are required: they are flipped below.
                assert attr not in ['begin_mask', 'end_mask'],\
                    '{} is not defined for the node {}'.format(attr, op.soft_get('name', op.id))
                op[attr] = int64_array([0])

        op.begin_mask = int64_array([1 - i for i in op.begin_mask])
        op.end_mask = int64_array([1 - i for i in op.end_mask])
    def extend(op: Node):
        """Normalize the mask attributes of a StridedSlice node restored from IR.

        Each mask named by ``StridedSlice.get_mask_names()`` that is present
        and not an empty string goes through ``Extender.attr_to_list``. Any
        other mask defaults to ``[0]``, except ``begin_mask`` and
        ``end_mask``, which must be defined. Finally ``begin_mask`` and
        ``end_mask`` are flipped element-wise (``1 - value``).

        :param op: StridedSlice node, updated in place.
        :raises AssertionError: if ``begin_mask`` or ``end_mask`` is not defined.
        """
        for mask_name in StridedSlice.get_mask_names():
            # op.has_and_set(mask_name) would be wrong here: it returns False when begin/end is
            # a 1D tensor and begin_mask/end_mask is equal to 0.
            mask_defined = op.has(mask_name) and op[mask_name] != ''
            if not mask_defined:
                assert mask_name not in ['begin_mask', 'end_mask'],\
                    '{} is not defined for the node {}'.format(mask_name, op.soft_get('name', op.id))
                op[mask_name] = int64_array([0])
                continue
            Extender.attr_to_list(op, mask_name)

        for flipped in ('begin_mask', 'end_mask'):
            setattr(op, flipped, int64_array([1 - bit for bit in getattr(op, flipped)]))
예제 #8
0
    def extend(op: Node):
        """Normalize the mask attributes of a StridedSlice node restored from IR.

        Each mask named by ``StridedSlice.get_mask_names()`` that is present
        and not an empty string goes through ``Extender.attr_to_list``. Any
        other optional mask defaults to ``[0]``. ``begin_mask`` and
        ``end_mask`` are then flipped element-wise (``1 - value``).

        :param op: StridedSlice node, updated in place.
        :raises AssertionError: if ``begin_mask`` or ``end_mask`` is not defined.
        """
        for attr in StridedSlice.get_mask_names():
            # We can not use op.has_and_set(attr) here as a condition, because it will return False if
            # begin/end is 1D tensor and begin_mask/end_mask is equal to 0
            if op.has(attr) and op[attr] != '':
                Extender.attr_to_list(op, attr)
            else:
                # begin/end masks are required: they are flipped below.
                assert attr not in ['begin_mask', 'end_mask'],\
                    '{} is not defined for the node {}'.format(attr, op.soft_get('name', op.id))
                op[attr] = int64_array([0])

        op.begin_mask = int64_array([1 - i for i in op.begin_mask])
        op.end_mask = int64_array([1 - i for i in op.end_mask])
예제 #9
0
 def extend(op: Node):
     """Normalize ``op``'s ``pyramid_scales`` attribute in place via ``Extender.attr_to_list``."""
     list_attr = 'pyramid_scales'
     Extender.attr_to_list(op, list_attr)
예제 #10
0
 def extend(op: Node):
     """Normalize ``op``'s ``axes`` attribute in place via ``Extender.attr_to_list``."""
     list_attr = 'axes'
     Extender.attr_to_list(op, list_attr)