def extract(node):
    """Fill in ReorgYolo attributes on ``node`` from its protobuf definition.

    Stores fixed ``batch_dims``, ``channel_dims`` and ``spatial_dims`` entries
    on the node, then forwards the stride to ``ReorgYoloOp.update_node_stat``.

    :param node: graph node exposing item assignment and a ``pb`` attribute
        whose ``attr['strides'].list.i`` is an indexable sequence.
    :return: the ``enabled`` attribute of the enclosing class.
    """
    # Fixed dimension indices written straight onto the node.
    # NOTE(review): 0 / 3 / [1, 2] look like an NHWC layout - confirm.
    dim_attrs = (
        ('batch_dims', 0),
        ('channel_dims', 3),
        ('spatial_dims', [1, 2]),
    )
    for attr_name, attr_value in dim_attrs:
        node[attr_name] = attr_value

    # Only the element at index 1 of the 'strides' list is used as the stride.
    strides = node.pb.attr['strides'].list.i
    stride_attrs = {'stride': np.array(strides[1])}
    ReorgYoloOp.update_node_stat(node, stride_attrs)

    # `__class__` resolves to the enclosing class, so this function has to be
    # defined inside a class body for the lookup to succeed.
    return __class__.enabled
def extract(cls, node):
    """Fill in ReorgYolo attributes on ``node`` from its ``reorg_yolo_param``.

    Reads ``node.pb.reorg_yolo_param``, merges it with an explicit ``stride``
    entry via ``merge_attrs``, adds whatever ``layout_attrs()`` returns, and
    hands the resulting mapping to ``ReorgYoloOp.update_node_stat``.

    :param node: graph node whose ``pb`` attribute carries a
        ``reorg_yolo_param`` message with a ``stride`` field.
    :return: ``cls.enabled``.
    """
    reorg_param = node.pb.reorg_yolo_param

    # NOTE(review): merge_attrs presumably combines the fields of the param
    # message with the explicit overrides passed here - confirm against helper.
    node_attrs = merge_attrs(reorg_param, {'stride': reorg_param.stride})
    node_attrs.update(layout_attrs())

    # Push the collected attributes onto the node.
    ReorgYoloOp.update_node_stat(node, node_attrs)
    return cls.enabled
def test_reorgyolo_infer(self):
    """Check the shape inferred by ``ReorgYoloOp.reorgyolo_infer``.

    Builds the graph ``node_1 -> reorg -> node_3 -> op_output`` with a
    ``[1, 3, 227, 227]`` input and stride 2, runs shape inference on the
    ``reorg`` node and compares the shape stored on ``node_3`` with the
    value returned by ``calculate_reorgyolo_output``.
    """
    input_dims = [1, 3, 227, 227]
    stride = 2

    graph = build_graph(nodes_attributes,
                        [('node_1', 'reorg'),
                         ('reorg', 'node_3'),
                         ('node_3', 'op_output')
                         ],
                        {'node_3': {'shape': None, 'value': None},
                         'node_1': {'shape': np.array(input_dims), 'value': None},
                         'reorg': {'stride': stride, **layout_attrs()}
                         })

    reorg_node = Node(graph, 'reorg')
    ReorgYoloOp.reorgyolo_infer(reorg_node)

    # A separate array is built for the reference computation so it cannot be
    # affected by anything inference does to the array stored in the graph.
    exp_shape = calculate_reorgyolo_output(np.array(input_dims), stride)
    res_shape = graph.node['node_3']['shape']

    # Fail with a clear message if inference left the output shape unset.
    self.assertIsNotNone(res_shape)
    # Comparing element-wise alone would let a result with extra trailing
    # dimensions pass, so the number of dimensions is checked explicitly.
    self.assertEqual(len(exp_shape), len(res_shape))
    for expected_dim, actual_dim in zip(exp_shape, res_shape):
        self.assertEqual(expected_dim, actual_dim)