Exemplo n.º 1
0
def conv_set_params(conv_param, conv_type):
    # Defaults
    padding = [0, 0]
    stride = [1, 1]
    kernel = [0, 0]
    dilate = [1, 1]
    group = 1

    kernel = get_spatial_attr(kernel, 'kernel_size', 'kernel', conv_param)
    padding = get_spatial_attr(padding, 'pad', 'pad', conv_param)
    stride = get_spatial_attr(stride, 'stride', 'stride', conv_param)
    dilates = get_list_from_container(conv_param, 'dilation', int)
    if len(dilates) > 0:
        dilate[0] = dilate[1] = dilates[0]

    groups = get_list_from_container(conv_param, 'group', int)
    group = groups[0] if len(groups) > 0 and groups[0] != 1 else group

    return {
        'type_str': conv_type,
        'padding': padding,
        'dilate': dilate,
        'stride': stride,
        'kernel': kernel,
        'group': group,
        'output': conv_param.num_output,
        'bias_term': conv_param.bias_term
    }
Exemplo n.º 2
0
    def extract(cls, node):
        proto_layer = node.pb
        param = proto_layer.pooling_param

        method = 'max'
        exclude_pad = True
        kernel = [0, 0]
        stride = [1, 1]
        padding = [0, 0]
        global_pooling = False

        if hasattr(param, 'global_pooling') and param.global_pooling:
            global_pooling = param.global_pooling
        else:
            kernel = get_spatial_attr(kernel, 'kernel_size', 'kernel', param)
            padding = get_spatial_attr(padding, 'pad', 'pad', param)
            stride = get_spatial_attr(stride, 'stride', 'stride', param)

        if param.pool == 0:
            method = 'max'
            exclude_pad = True
        elif param.pool == 1:
            method = 'avg'
            exclude_pad = False
        else:
            raise ValueError('Unknown Pooling Method!')

        pooling_convention = 'full'  # for Caffe rounding type should be ceil
        rt = 'ceil'

        if hasattr(param, 'ceil_mode') and not param.ceil_mode:
            # If pooling has ceil_mode and ceil_mode is False using floor for rounding shapes in partial_infer
            pooling_convention = 'valid'
            rt = 'floor'

        attrs = {
            'window': int64_array([1, 1, kernel[1], kernel[0]]),
            'stride': int64_array([1, 1, stride[1], stride[0]]),
            'pad': int64_array([[0, 0], [0, 0], [padding[1], padding[1]], [padding[0], padding[0]]]),
            'pad_spatial_shape': int64_array([[padding[1], padding[1]], [padding[0], padding[0]]]),
            'pool_method': method,
            'exclude_pad': exclude_pad,
            'global_pool': global_pooling,
            'output_spatial_shape': None,
            'rounding_type': rt
        }

        attrs.update(layout_attrs())
        attrs['pooling_convention'] = pooling_convention

        # update the attributes of the node
        Pooling.update_node_stat(node, attrs)
        return cls.enabled