def conv_set_params(conv_param, conv_type):
    """Build the attribute dictionary describing a convolution layer.

    Args:
        conv_param: Layer parameter object. It must expose ``num_output`` and
            ``bias_term`` attributes and is queried for ``kernel_size``/
            ``kernel``, ``pad``, ``stride``, ``dilation`` and ``group``.
            NOTE(review): presumably a Caffe ``ConvolutionParameter`` — confirm.
        conv_type: Value stored verbatim under the ``'type_str'`` key.

    Returns:
        dict with keys ``type_str``, ``padding``, ``dilate``, ``stride``,
        ``kernel``, ``group``, ``output`` and ``bias_term``.
    """
    # The first argument to get_spatial_attr appears to be the fallback used
    # when the attribute is absent (kernel [0, 0], pad [0, 0], stride [1, 1]).
    kernel = get_spatial_attr([0, 0], 'kernel_size', 'kernel', conv_param)
    padding = get_spatial_attr([0, 0], 'pad', 'pad', conv_param)
    stride = get_spatial_attr([1, 1], 'stride', 'stride', conv_param)

    # Only the first dilation value is honoured; it is copied to both axes.
    dilation_values = get_list_from_container(conv_param, 'dilation', int)
    if dilation_values:
        dilate = [dilation_values[0], dilation_values[0]]
    else:
        dilate = [1, 1]

    # Group count defaults to 1 unless a different first value is supplied.
    group_values = get_list_from_container(conv_param, 'group', int)
    group = 1
    if group_values and group_values[0] != 1:
        group = group_values[0]

    return {
        'type_str': conv_type,
        'padding': padding,
        'dilate': dilate,
        'stride': stride,
        'kernel': kernel,
        'group': group,
        'output': conv_param.num_output,
        'bias_term': conv_param.bias_term,
    }
def test_get_list_from_container_list_match_empty(self):
    """A field holding an empty list is returned as an empty list."""
    param = FakeParam('prop', [])
    self.assertEqual(get_list_from_container(param, 'prop', int), [])
def test_get_list_from_container_no_param(self):
    """Passing ``None`` as the container yields an empty list."""
    self.assertEqual(get_list_from_container(None, 'prop', int), [])
def test_get_list_from_container_simple_type_match(self):
    """A scalar value matching the requested type is wrapped in a list."""
    param = FakeParam('prop', 10)
    self.assertEqual(get_list_from_container(param, 'prop', int), [10])
def test_get_list_from_container_no_existing_param(self):
    """Asking for a field the container does not have yields an empty list."""
    param = FakeParam("p", "1")
    self.assertEqual(get_list_from_container(param, 'prop', int), [])