def conv_set_params(conv_param, conv_type):
    """Collect the attributes of a convolution-like Caffe layer into a dict.

    Args:
        conv_param: the layer's convolution parameter message (presumably a
            Caffe ``ConvolutionParameter`` -- TODO confirm against callers).
            It is read for kernel/pad/stride (via ``get_spatial_attr``) and for
            ``dilation``, ``group``, ``num_output`` and ``bias_term``.
        conv_type: stored verbatim under the ``'type_str'`` key.

    Returns:
        dict with keys ``type_str``, ``padding``, ``dilate``, ``stride``,
        ``kernel``, ``group``, ``output`` and ``bias_term``.  The spatial pairs
        are assumed to be in (width, height) order, matching how the results of
        ``get_spatial_attr`` are indexed elsewhere in this file -- TODO confirm.
    """
    # Defaults used when the prototxt does not specify a value.
    padding = [0, 0]
    stride = [1, 1]
    kernel = [0, 0]
    dilate = [1, 1]

    kernel = get_spatial_attr(kernel, 'kernel_size', 'kernel', conv_param)
    padding = get_spatial_attr(padding, 'pad', 'pad', conv_param)
    stride = get_spatial_attr(stride, 'stride', 'stride', conv_param)

    # Caffe's `dilation` is repeated: either one value shared by every spatial
    # axis, or one value per axis in (height, width) order.  The original code
    # only looked at the first entry, which dropped a distinct width dilation.
    dilates = get_list_from_container(conv_param, 'dilation', int)
    if len(dilates) == 2:
        dilate = [dilates[1], dilates[0]]  # stored as (width, height)
    elif len(dilates) > 0:
        dilate = [dilates[0], dilates[0]]

    # An explicit group count overrides the default of 1.
    groups = get_list_from_container(conv_param, 'group', int)
    group = groups[0] if len(groups) > 0 else 1

    return {
        'type_str': conv_type,
        'padding': padding,
        'dilate': dilate,
        'stride': stride,
        'kernel': kernel,
        'group': group,
        'output': conv_param.num_output,
        'bias_term': conv_param.bias_term,
    }
def extract(cls, node):
    """Translate the Caffe pooling parameters of ``node`` into Pooling attributes.

    Reads ``node.pb.pooling_param``.  Pool kind 0 maps to ``'max'`` (padding
    excluded) and kind 1 to ``'avg'`` (padding included).  Any other kind raises
    ``ValueError``.  When global pooling is requested, kernel/pad/stride are not
    read and keep their defaults.  Output-shape rounding is ceil ("full")
    unless the message has ``ceil_mode`` set to a falsy value, in which case it
    becomes floor ("valid").

    The resulting attributes are written with ``Pooling.update_node_stat`` and
    ``cls.enabled`` is returned.
    """
    pool_param = node.pb.pooling_param

    # Truthy value of `global_pooling` when present and set, otherwise False.
    is_global = False
    if hasattr(pool_param, 'global_pooling') and pool_param.global_pooling:
        is_global = pool_param.global_pooling

    if is_global:
        kernel, padding, stride = [0, 0], [0, 0], [1, 1]
    else:
        kernel = get_spatial_attr([0, 0], 'kernel_size', 'kernel', pool_param)
        padding = get_spatial_attr([0, 0], 'pad', 'pad', pool_param)
        stride = get_spatial_attr([1, 1], 'stride', 'stride', pool_param)

    # pool kind -> (pool_method, exclude_pad)
    known_kinds = {0: ('max', True), 1: ('avg', False)}
    chosen = known_kinds.get(pool_param.pool)
    if chosen is None:
        raise ValueError('Unknown Pooling Method!')
    method, exclude_pad = chosen

    # Caffe rounds pooled shapes up by default; only an explicitly disabled
    # ceil_mode switches to floor rounding in partial_infer.
    floor_rounding = hasattr(pool_param, 'ceil_mode') and not pool_param.ceil_mode
    rounding_type = 'floor' if floor_rounding else 'ceil'
    convention = 'valid' if floor_rounding else 'full'

    # Index 1 is used for the H slot and index 0 for the W slot of the
    # presumably (N, C, H, W) layout -- TODO confirm against layout_attrs().
    pad_rows = [padding[1], padding[1]]
    pad_cols = [padding[0], padding[0]]
    attributes = {
        'window': np.array([1, 1, kernel[1], kernel[0]], dtype=np.int64),
        'stride': np.array([1, 1, stride[1], stride[0]], dtype=np.int64),
        'pad': np.array([[0, 0], [0, 0], pad_rows, pad_cols], dtype=np.int64),
        'pad_spatial_shape': np.array([pad_rows, pad_cols], dtype=np.int64),
        'pool_method': method,
        'exclude_pad': exclude_pad,
        'global_pool': is_global,
        'output_spatial_shape': None,
        'rounding_type': rounding_type,
    }
    attributes.update(layout_attrs())
    attributes['pooling_convention'] = convention

    Pooling.update_node_stat(node, attributes)
    return cls.enabled