Ejemplo n.º 1
0
def conv_set_params(conv_param, conv_type):
    """Gather the attributes of a convolution-like layer into a plain dict.

    Kernel, padding and stride are resolved through ``get_spatial_attr``;
    dilation and group are read through ``get_list_from_container``.

    :param conv_param: layer parameter object; besides whatever the helpers
        read from it, it must expose ``num_output`` and ``bias_term``.
    :param conv_type: value stored unchanged under the ``'type_str'`` key.
    :return: dict with the keys ``type_str``, ``padding``, ``dilate``,
        ``stride``, ``kernel``, ``group``, ``output`` and ``bias_term``.
    """
    # Each spatial attribute is looked up with its own two-element default.
    kernel = get_spatial_attr([0, 0], 'kernel_size', 'kernel', conv_param)
    padding = get_spatial_attr([0, 0], 'pad', 'pad', conv_param)
    stride = get_spatial_attr([1, 1], 'stride', 'stride', conv_param)

    # Only the first dilation value is consulted; it is applied to both
    # entries of the pair. Without any value the pair stays at [1, 1].
    dilation_values = get_list_from_container(conv_param, 'dilation', int)
    if len(dilation_values) > 0:
        dilate = [dilation_values[0], dilation_values[0]]
    else:
        dilate = [1, 1]

    # Group falls back to 1 unless a first value different from 1 is present.
    group = 1
    group_values = get_list_from_container(conv_param, 'group', int)
    if len(group_values) > 0 and group_values[0] != 1:
        group = group_values[0]

    params = {}
    params['type_str'] = conv_type
    params['padding'] = padding
    params['dilate'] = dilate
    params['stride'] = stride
    params['kernel'] = kernel
    params['group'] = group
    params['output'] = conv_param.num_output
    params['bias_term'] = conv_param.bias_term
    return params
Ejemplo n.º 2
0
 def test_get_list_from_container_list_match_empty(self):
     """An attribute holding an empty list produces an empty result."""
     container = FakeParam('prop', [])
     self.assertEqual(get_list_from_container(container, 'prop', int), [])
Ejemplo n.º 3
0
 def test_get_list_from_container_no_param(self):
     """Passing ``None`` instead of a container produces an empty result."""
     extracted = get_list_from_container(None, 'prop', int)
     expected = []
     self.assertEqual(extracted, expected)
Ejemplo n.º 4
0
 def test_get_list_from_container_simple_type_match(self):
     """A single int attribute comes back as a one-element list."""
     container = FakeParam('prop', 10)
     self.assertEqual(get_list_from_container(container, 'prop', int), [10])
Ejemplo n.º 5
0
 def test_get_list_from_container_no_existing_param(self):
     """Requesting 'prop' from a container built with 'p' gives an empty list."""
     container = FakeParam("p", "1")
     extracted = get_list_from_container(container, 'prop', int)
     self.assertEqual(extracted, [])