示例#1
0
    def test_params_creation(self):
        """Check that conv_set_params falls back to the per-axis Caffe fields.

        The scalar ``pad``/``kernel_size``/``stride``/``dilation`` fields are
        None here. The expected result is therefore built from the
        ``*_w``/``*_h`` fields, ordered as ``[w, h]``: pad_w=3/pad_h=4 gives
        ``[3, 4]`` and stride_w=2/stride_h=3 gives ``[2, 3]``. Dilation is
        expected to default to ``[1, 1]``. ``group`` and ``num_output`` are
        expected to pass through as ``group`` and ``output``.
        """
        params = {
            'pad': None,
            'kernel_size': None,
            'stride': None,
            'dilation': None,
            'group': 14,
            'num_output': 15,
            'bias_term': True,
            'pad_w': 3,
            'pad_h': 4,
            'kernel_w': 5,
            'kernel_h': 6,
            'stride_h': 3,
            'stride_w': 2,
        }
        exp_res = {
            'padding': [3, 4],
            'stride': [2, 3],
            'kernel': [5, 6],
            'dilate': [1, 1],
            'group': 14,
            'output': 15,
        }
        res = conv_set_params(
            FakeConvProtoLayer(FakeMultiParam(params)).convolution_param,
            'Conv2D')

        # List-valued results are compared with numpy, presumably because
        # conv_set_params may return arrays for them; the rest are scalars.
        array_keys = {'padding', 'stride', 'kernel', 'dilate'}
        for key, expected in exp_res.items():
            if key in array_keys:
                np.testing.assert_equal(res[key], expected, err_msg=key)
            else:
                self.assertEqual(res[key], expected, msg=key)
示例#2
0
    def extract(node):
        """Populate a convolution node's attributes from its protobuf layers.

        Reads ``node.pb`` (the layer's protobuf message) and ``node.model_pb``
        (forwarded to ``weights_biases`` as the model layer). It builds the
        attribute dict and applies it with ``Convolution.update_node_stat``.

        Args:
            node: graph node exposing ``pb`` and ``model_pb``; it is updated
                in place through ``Convolution.update_node_stat``.

        Returns:
            The ``enabled`` attribute of the enclosing class.

        Raises:
            Error: if ``node.pb`` is empty (falsy).
        """
        proto_layer = node.pb
        model_layer = node.model_pb
        if not proto_layer:
            raise Error('Protobuf layer can not be empty')

        # Implicit class cell: this only works because the def sits inside a
        # class body (the class/decorator lines are outside this view).
        extractor_cls = __class__

        conv_param = proto_layer.convolution_param
        input_count = len(proto_layer.bottom)
        # More than one bottom input selects the ConvND flavour.
        is_conv2d = input_count <= 1
        conv_type = 'Conv2D' if is_conv2d else 'ConvND'

        attrs = conv_create_attrs(conv_set_params(conv_param, conv_type))
        attrs.update({
            'op': extractor_cls.op,
            'get_group': lambda node: node.group,
            'get_output_feature_dim': lambda node: node.output,
            # NOTE(review): presumably the weights input shifts by one when
            # an extra bottom input is present — confirm against consumers.
            'weights_index': 1 if is_conv2d else 2,
        })

        # Weights and biases are embedded as attributes here; per the original
        # note, a separate pass later moves them into their own nodes.
        weight_attrs = weights_biases(
            conv_param.bias_term,
            model_layer,
            start_index=input_count,
            proto=conv_param,
        )
        attrs.update(weight_attrs)
        attrs.update(layout_attrs())

        Convolution.update_node_stat(node, attrs)
        return extractor_cls.enabled