Exemplo n.º 1
0
    def test_tf_concat_infer(self):
        """Check the output shape that ``CorrelationOp.corr_infer`` assigns.

        Builds a graph where two 1x3x227x227 inputs (``node_1``, ``node_2``)
        feed a ``corr`` node whose output goes to ``node_3``.
        - ``corr`` parameters: pad 20, kernel_size 1, max_displacement 20,
          stride_1 1, stride_2 2.
        After running inference on ``corr``, the shape stored on ``node_3``
        must equal ``[1, 441, 227, 227]``. Both values and rank must match.
        """
        # NOTE(review): the name mentions concat, but the body exercises
        # CorrelationOp.corr_infer. The name is kept unchanged so the test ID
        # stays stable.
        graph = build_graph(nodes_attributes,
                            [
                                ('node_1', 'corr'),
                                ('node_2', 'corr'),
                                ('corr', 'node_3'),
                                ('node_3', 'op_output')
                            ],
                            {
                                'node_3': {'shape': None},
                                'node_1': {'shape': np.array([1, 3, 227, 227])},
                                'node_2': {'shape': np.array([1, 3, 227, 227])},
                                'corr': {'pad': 20,
                                         'kernel_size': 1,
                                         'max_displacement': 20,
                                         'stride_1': 1,
                                         'stride_2': 2,
                                         'single_direction': 0,
                                         'do_abs': False,
                                         'correlation_type': 0}
                            })

        corr_node = Node(graph, 'corr')
        CorrelationOp.corr_infer(corr_node)
        exp_shape = np.array([1, 441, 227, 227])
        res_shape = graph.node['node_3']['shape']
        # Fail clearly if inference never set the output shape. A None here
        # used to surface as a TypeError rather than a test failure.
        self.assertIsNotNone(res_shape)
        # Compare the whole array: this checks the length too. The old loop
        # passed silently when the result had extra trailing dimensions.
        np.testing.assert_array_equal(res_shape, exp_shape)
Exemplo n.º 2
0
    def extract(cls, node):
        """Copy the Caffe ``correlation_param`` fields onto *node*.

        The fields are read from ``node.pb.correlation_param`` and renamed to
        the attribute keys ``CorrelationOp`` expects. Two fields are converted:
        - ``do_abs`` is stored as an ``int``.
        - ``correlation_type`` becomes a string tag: ``...SUBTRACT`` when the
          proto value is 1, otherwise ``...MULTIPLY``.
        The result is passed through ``merge_attrs`` together with the proto
        message, and ``layout_attrs()`` is added on top. Finally the
        attributes are applied with ``CorrelationOp.update_node_stat``.

        Returns:
            ``cls.enabled``.
        """
        layer = node.pb
        corr_param = layer.correlation_param

        # Only the value 1 selects SUBTRACT; every other value is treated as MULTIPLY.
        if corr_param.correlation_type == 1:
            corr_kind = 'caffe.CorrelationParameter.SUBTRACT'
        else:
            corr_kind = 'caffe.CorrelationParameter.MULTIPLY'

        extracted = dict(
            pad=corr_param.pad,
            kernel_size=corr_param.kernel_size,
            max_displacement=corr_param.max_displacement,
            stride_1=corr_param.stride_1,
            stride_2=corr_param.stride_2,
            single_direction=corr_param.single_direction,
            do_abs=int(corr_param.do_abs),
            correlation_type=corr_kind,
        )

        node_attrs = merge_attrs(corr_param, extracted)
        node_attrs.update(layout_attrs())

        CorrelationOp.update_node_stat(node, node_attrs)
        return cls.enabled